diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f0d029f804..1028d1a089 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -690,13 +690,15 @@ jobs: - name: Pack artifacts id: pack_artifacts run: | - tar -czvf llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz -C build-apple llama.xcframework + # Zip file is required for Swift Package Manager, which does not support tar.gz for binary targets. + # For more details, see https://developer.apple.com/documentation/xcode/distributing-binary-frameworks-as-swift-packages + zip -r -y llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework - name: Upload artifacts uses: actions/upload-artifact@v4 with: - path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz - name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz + path: llama-${{ steps.tag.outputs.name }}-xcframework.zip + name: llama-${{ steps.tag.outputs.name }}-xcframework.zip openEuler-cann: @@ -865,7 +867,7 @@ jobs: **macOS/iOS:** - [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz) - [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz) - - [iOS XCFramework](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz) + - [iOS XCFramework](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-xcframework.zip) **Linux:** - [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz) diff --git a/common/arg.cpp b/common/arg.cpp index da887f6cfe..e1411b679c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2094,7 +2094,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "override tensor buffer type", [](common_params & params, const std::string & value) { parse_tensor_buffer_overrides(value, params.tensor_buft_overrides); } - )); + ).set_env("LLAMA_ARG_OVERRIDE_TENSOR")); add_opt(common_arg( {"-otd", "--override-tensor-draft"}, "=,...", "override tensor buffer type for draft model", [](common_params & params, const std::string & value) { diff --git a/common/common.cpp b/common/common.cpp index 9792c0b6a6..34ed3c16c8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1078,6 +1078,8 @@ struct common_init_result::impl { impl() = default; ~impl() = default; + // note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top + llama_model_ptr model; llama_context_ptr context; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 432be59946..16c5acf346 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -141,16 +141,24 @@ class ModelBase: self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type + # Apply heuristics to figure out typical tensor encoding based on first tensor's dtype + # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. 
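An aside on the hunk in progress: the replacement heuristic below scans for the first tensor with at least two dimensions instead of trusting only the very first tensor. A minimal standalone sketch of the same logic, with the hypothetical `guess_outtype` and `iter_tensors` standing in for the converter's method and its `get_tensors()`:

```python
import torch

def guess_outtype(iter_tensors):
    # skip 1-D tensors (norms, biases), which are commonly stored in f32
    # even when the model's weights are f16/bf16
    for _, tensor in iter_tensors:
        if tensor.dim() < 2:
            continue
        if tensor.dtype == torch.bfloat16:
            return "bf16"
        if tensor.dtype == torch.float16:
            return "f16"
    return "f16"  # no 16-bit float weight found: default to f16
```

Note that a 2-D tensor whose dtype is neither f16 nor bf16 (e.g. f32) does not end the scan; the loop keeps looking, which is what the `for`/`else` in the hunk below expresses.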
if self.ftype == gguf.LlamaFileType.GUESSED: - # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. - _, first_tensor = next(self.get_tensors()) - if first_tensor.dtype == torch.float16: - logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})") - self.ftype = gguf.LlamaFileType.MOSTLY_F16 + for _, tensor in self.get_tensors(): + if tensor.dim() < 2: + continue + + if tensor.dtype == torch.bfloat16: + self.ftype = gguf.LlamaFileType.MOSTLY_BF16 + logger.info("heuristics detected bfloat16 tensor dtype, setting --outtype bf16") + break + elif tensor.dtype == torch.float16: + self.ftype = gguf.LlamaFileType.MOSTLY_F16 + logger.info("heuristics detected float16 tensor dtype, setting --outtype f16") + break else: - logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") - self.ftype = gguf.LlamaFileType.MOSTLY_BF16 + self.ftype = gguf.LlamaFileType.MOSTLY_F16 + logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16") self.dequant_model() @@ -1204,6 +1212,9 @@ class TextModel(ModelBase): if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756": # ref: https://huggingface.co/JetBrains/Mellum-4b-base res = "mellum" + if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152": + # ref: https://huggingface.co/answerdotai/ModernBERT-base + res = "modern-bert" if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df": # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer res = "afmoe" @@ -9991,6 +10002,36 @@ class SmallThinkerModel(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification") +class ModernBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.MODERN_BERT + + def set_vocab(self): + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + self.gguf_writer.add_add_sep_token(True) + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_sliding_window(self.hparams["local_attention"]) + if (sliding_window_pattern := self.hparams.get("global_attn_every_n_layers")) is not None: + self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern) + self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("local_rope_theta")})["rope_theta"]) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # these layers act as MLM head, so we don't need them + if name.startswith("decoder."): + return [] + + if name.startswith("model."): + name = name[6:] + + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("ApertusForCausalLM") class ApertusModel(LlamaModel): model_arch = gguf.MODEL_ARCH.APERTUS @@ -10557,8 +10598,8 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. 
{ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", - help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="auto", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type", ) parser.add_argument( "--bigendian", action="store_true", diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 5e8456a7ea..4378378309 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -139,6 +139,7 @@ models = [ {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"}, {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, + {"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/answerdotai/ModernBERT-base", }, {"name": "afmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", }, {"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", }, {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, diff --git a/docs/backend/hexagon/README.md b/docs/backend/hexagon/README.md index 85f136ef9e..00ec3a7e71 100644 --- a/docs/backend/hexagon/README.md +++ b/docs/backend/hexagon/README.md @@ -106,7 +106,7 @@ Here are some examples of running various llama.cpp tools via ADB. Simple question for Llama-3.2-1B ``` -~/src/llama.cpp$ M=Llama-3.2-1B-Instruct-Q4_0.gguf D=HTP0 ./scripts/snapdragon/adb/run-cli.sh -no-cnv -p "what is the most popular cookie in the world?" +~/src/llama.cpp$ M=Llama-3.2-1B-Instruct-Q4_0.gguf D=HTP0 ./scripts/snapdragon/adb/run-completion.sh -p "what is the most popular cookie in the world?" ... ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1 ggml-hex: Hexagon Arch version v79 @@ -136,7 +136,7 @@ llama_memory_breakdown_print: | - HTP0-REPACK | 504 = Summary request for OLMoE-1B-7B. This is a large model that requires two HTP sessions/devices ``` -~/src/llama.cpp$ M=OLMoE-1B-7B-0125-Instruct-Q4_0.gguf NDEV=2 D=HTP0,HTP1 ./scripts/snapdragon/adb/run-cli.sh -f surfing.txt -no-cnv +~/src/llama.cpp$ M=OLMoE-1B-7B-0125-Instruct-Q4_0.gguf NDEV=2 D=HTP0,HTP1 ./scripts/snapdragon/adb/run-completion.sh -f surfing.txt ... 
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1 ggml-hex: Hexagon Arch version v81 @@ -234,6 +234,6 @@ build: 6a8cf8914 (6733) Examples: - `GGML_HEXAGON_OPMASK=0x1 llama-cli ...` - Ops are enqueued but NPU-side processing is stubbed out - `GGML_HEXAGON_OPMASK=0x3 llama-cli ...` - NPU performs dynamic quantization and skips the rest - `GGML_HEXAGON_OPMASK=0x7 llama-cli ...` - Full queuing and processing of Ops (default) + `GGML_HEXAGON_OPMASK=0x1 llama-completion ...` - Ops are enqueued but NPU-side processing is stubbed out + `GGML_HEXAGON_OPMASK=0x3 llama-completion ...` - NPU performs dynamic quantization and skips the rest + `GGML_HEXAGON_OPMASK=0x7 llama-completion ...` - Full queuing and processing of Ops (default) diff --git a/docs/backend/hexagon/developer.md b/docs/backend/hexagon/developer.md index 200a7aabc0..fc4d160e93 100644 --- a/docs/backend/hexagon/developer.md +++ b/docs/backend/hexagon/developer.md @@ -49,7 +49,7 @@ Each Hexagon device behaves like a GPU from the offload and model splitting pers Here is an example of running GPT-OSS-20B model on a newer Snapdragon device with 16GB of DDR. ``` -M=gpt-oss-20b-Q4_0.gguf NDEV=4 D=HTP0,HTP1,HTP2,HTP3 P=surfing.txt scripts/snapdragon/adb/run-cli.sh -no-cnv -f surfing.txt -n 32 +M=gpt-oss-20b-Q4_0.gguf NDEV=4 D=HTP0,HTP1,HTP2,HTP3 P=surfing.txt scripts/snapdragon/adb/run-completion.sh -f surfing.txt -n 32 ... LD_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib ADSP_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib diff --git a/examples/gen-docs/gen-docs.cpp b/examples/gen-docs/gen-docs.cpp index dc76c4cf53..0aa33e8245 100644 --- a/examples/gen-docs/gen-docs.cpp +++ b/examples/gen-docs/gen-docs.cpp @@ -2,57 +2,74 @@ #include "common.h" #include +#include #include // Export usage message (-h) to markdown format +// Automatically update the markdown docs -static void write_table_header(std::ofstream & file) { - file << "| Argument | Explanation |\n"; - file << "| -------- | ----------- |\n"; +#define HELP_START_MARKER "" +#define HELP_END_MARKER "" +#define NOTE_MESSAGE "" + +struct md_file { + llama_example ex; + std::string fname; + std::string specific_section_header; +}; + +std::vector md_files = { + {LLAMA_EXAMPLE_CLI, "tools/cli/README.md", "CLI-specific params"}, + {LLAMA_EXAMPLE_COMPLETION, "tools/completion/README.md", "Completion-specific params"}, + {LLAMA_EXAMPLE_SERVER, "tools/server/README.md", "Server-specific params"}, +}; + +static void write_table_header(std::ostringstream & ss) { + ss << "| Argument | Explanation |\n"; + ss << "| -------- | ----------- |\n"; } -static void write_table_entry(std::ofstream & file, const common_arg & opt) { - file << "| `"; +static void write_table_entry(std::ostringstream & ss, const common_arg & opt) { + ss << "| `"; // args auto all_args = opt.get_args(); for (const auto & arg : all_args) { if (arg == all_args.front()) { - file << arg; - if (all_args.size() > 1) file << ", "; + ss << arg; + if (all_args.size() > 1) ss << ", "; } else { - file << arg << (arg != all_args.back() ? ", " : ""); + ss << arg << (arg != all_args.back() ? 
", " : ""); } } // value hint if (opt.value_hint) { std::string md_value_hint(opt.value_hint); string_replace_all(md_value_hint, "|", "\\|"); - file << " " << md_value_hint; + ss << " " << md_value_hint; } if (opt.value_hint_2) { std::string md_value_hint_2(opt.value_hint_2); string_replace_all(md_value_hint_2, "|", "\\|"); - file << " " << md_value_hint_2; + ss << " " << md_value_hint_2; } // help text std::string md_help(opt.help); + md_help = string_strip(md_help); string_replace_all(md_help, "\n", "
"); string_replace_all(md_help, "|", "\\|"); - file << "` | " << md_help << " |\n"; + ss << "` | " << md_help << " |\n"; } -static void write_table(std::ofstream & file, std::vector & opts) { - write_table_header(file); +static void write_table(std::ostringstream & ss, std::vector & opts) { + write_table_header(ss); for (const auto & opt : opts) { - write_table_entry(file, *opt); + write_table_entry(ss, *opt); } } -static void export_md(std::string fname, llama_example ex, std::string name) { - std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc); - +static void write_help(std::ostringstream & ss, const md_file & md) { common_params params; - auto ctx_arg = common_params_parser_init(params, ex); + auto ctx_arg = common_params_parser_init(params, md.ex); std::vector common_options; std::vector sparam_options; @@ -68,18 +85,58 @@ static void export_md(std::string fname, llama_example ex, std::string name) { } } - file << "**Common params**\n\n"; - write_table(file, common_options); - file << "\n\n**Sampling params**\n\n"; - write_table(file, sparam_options); - file << "\n\n**" << name << "-specific params**\n\n"; - write_table(file, specific_options); + ss << HELP_START_MARKER << "\n\n"; + + ss << NOTE_MESSAGE << "\n\n"; + + ss << "### Common params\n\n"; + write_table(ss, common_options); + ss << "\n\n### Sampling params\n\n"; + write_table(ss, sparam_options); + ss << "\n\n### " << md.specific_section_header << "\n\n"; + write_table(ss, specific_options); + + ss << "\n" << HELP_END_MARKER; } int main(int, char **) { - // TODO: add CLI - export_md("autogen-completion.md", LLAMA_EXAMPLE_COMPLETION, "Tool"); - export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER, "Server"); + for (const auto & md : md_files) { + std::ifstream infile(md.fname); + if (!infile.is_open()) { + fprintf(stderr, "failed to open file '%s' for reading\n", md.fname.c_str()); + return 1; + } + + std::ostringstream ss; + ss << infile.rdbuf(); + infile.close(); + + std::string content = ss.str(); + + size_t help_start = content.find(HELP_START_MARKER); + size_t help_end = content.find(HELP_END_MARKER); + + if (help_start == std::string::npos || help_end == std::string::npos || help_end <= help_start) { + fprintf(stderr, "failed to find help markers in file '%s'\n", md.fname.c_str()); + return 1; + } + + std::ostringstream new_help_ss; + write_help(new_help_ss, md); + std::string new_help = new_help_ss.str(); + + content = content.substr(0, help_start) + new_help + content.substr(help_end + strlen(HELP_END_MARKER)); + + std::ofstream outfile(md.fname); + if (!outfile.is_open()) { + fprintf(stderr, "failed to open file '%s' for writing\n", md.fname.c_str()); + return 1; + } + outfile << content; + outfile.close(); + + printf("Updated help in '%s'\n", md.fname.c_str()); + } return 0; } diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index 25b0514b29..f8dc525a77 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -25,6 +25,8 @@ define quantize_model @echo "Export the quantized model path to $(2) variable in your environment" endef +DEVICE ?= auto + ### ### Casual Model targets/recipes ### @@ -53,7 +55,7 @@ causal-convert-mm-model: causal-run-original-model: $(call validate_model_path,causal-run-original-model) - @MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py + @MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py --device "$(DEVICE)" causal-run-converted-model: @CONVERTED_MODEL="$(CONVERTED_MODEL)" 
./scripts/causal/run-converted-model.sh diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index 14bb12fe68..b12173a1fb 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -4,149 +4,179 @@ import argparse import os import sys import importlib +import torch +import numpy as np + from pathlib import Path +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig # Add parent directory to path for imports sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) - -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig -import torch -import numpy as np from utils.common import debug_hook -parser = argparse.ArgumentParser(description="Process model with specified path") -parser.add_argument("--model-path", "-m", help="Path to the model") -parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False) -parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output") -args = parser.parse_args() +def parse_arguments(): + parser = argparse.ArgumentParser(description="Process model with specified path") + parser.add_argument("--model-path", "-m", help="Path to the model") + parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False) + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output") + parser.add_argument("--device", "-d", help="Device to use (cpu, cuda, mps, auto)", default="auto") + return parser.parse_args() -model_path = os.environ.get("MODEL_PATH", args.model_path) -if model_path is None: - parser.error( - "Model path must be specified either via --model-path argument or MODEL_PATH environment variable" - ) +def load_model_and_tokenizer(model_path, device="auto"): + print("Loading model and tokenizer using AutoTokenizer:", model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + multimodal = False + full_config = config -### If you want to dump RoPE activations, uncomment the following lines: -### === START ROPE DEBUG === -# from utils.common import setup_rope_debug -# setup_rope_debug("transformers.models.apertus.modeling_apertus") -### == END ROPE DEBUG === - - -print("Loading model and tokenizer using AutoTokenizer:", model_path) -tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) -config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) -multimodal = False -full_config = config - -print("Model type: ", config.model_type) -if "vocab_size" not in config and "text_config" in config: - config = config.text_config - multimodal = True -print("Vocab size: ", config.vocab_size) -print("Hidden size: ", config.hidden_size) -print("Number of layers: ", config.num_hidden_layers) -print("BOS token id: ", config.bos_token_id) -print("EOS token id: ", config.eos_token_id) - -unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME") -if unreleased_model_name: - model_name_lower = unreleased_model_name.lower() - unreleased_module_path = ( - f"transformers.models.{model_name_lower}.modular_{model_name_lower}" - ) - class_name = f"{unreleased_model_name}ForCausalLM" - print(f"Importing unreleased model module: {unreleased_module_path}") - - try: - model_class = getattr( 
- importlib.import_module(unreleased_module_path), class_name - ) - model = model_class.from_pretrained( - model_path - ) # Note: from_pretrained, not fromPretrained - except (ImportError, AttributeError) as e: - print(f"Failed to import or load model: {e}") - exit(1) -else: - if multimodal: - model = AutoModelForImageTextToText.from_pretrained( - model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config - ) + # Determine device_map based on device argument + if device == "cpu": + device_map = {"": "cpu"} + print("Forcing CPU usage") + elif device == "auto": + device_map = "auto" else: - model = AutoModelForCausalLM.from_pretrained( - model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config + device_map = {"": device} + + print("Model type: ", config.model_type) + if "vocab_size" not in config and "text_config" in config: + config = config.text_config + multimodal = True + + print("Vocab size: ", config.vocab_size) + print("Hidden size: ", config.hidden_size) + print("Number of layers: ", config.num_hidden_layers) + print("BOS token id: ", config.bos_token_id) + print("EOS token id: ", config.eos_token_id) + + unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME") + if unreleased_model_name: + model_name_lower = unreleased_model_name.lower() + unreleased_module_path = ( + f"transformers.models.{model_name_lower}.modular_{model_name_lower}" ) + class_name = f"{unreleased_model_name}ForCausalLM" + print(f"Importing unreleased model module: {unreleased_module_path}") -if args.verbose: - for name, module in model.named_modules(): - if len(list(module.children())) == 0: # only leaf modules - module.register_forward_hook(debug_hook(name)) + try: + model_class = getattr(importlib.import_module(unreleased_module_path), class_name) + model = model_class.from_pretrained( + model_path, + device_map=device_map, + offload_folder="offload", + trust_remote_code=True, + config=config + ) + except (ImportError, AttributeError) as e: + print(f"Failed to import or load model: {e}") + exit(1) + else: + if multimodal: + model = AutoModelForImageTextToText.from_pretrained( + model_path, + device_map=device_map, + offload_folder="offload", + trust_remote_code=True, + config=full_config + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map=device_map, + offload_folder="offload", + trust_remote_code=True, + config=config + ) -model_name = os.path.basename(model_path) -# Printing the Model class to allow for easier debugging. This can be useful -# when working with models that have not been publicly released yet and this -# migth require that the concrete class is imported and used directly instead -# of using AutoModelForCausalLM. 
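The `device_map` selection introduced in the rewrite above reduces to a three-way switch; a sketch with the hypothetical name `resolve_device_map`:

```python
def resolve_device_map(device: str):
    # "cpu" pins every module to the CPU; "auto" lets transformers/accelerate
    # place (and, if needed, shard) modules across available devices;
    # anything else ("cuda", "mps", "cuda:1", ...) pins the whole model there
    if device == "cpu":
        return {"": "cpu"}
    if device == "auto":
        return "auto"
    return {"": device}

# usage: AutoModelForCausalLM.from_pretrained(path, device_map=resolve_device_map(args.device))
```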
-print(f"Model class: {model.__class__.__name__}") + print(f"Model class: {model.__class__.__name__}") -device = next(model.parameters()).device -if args.prompt_file: - with open(args.prompt_file, encoding='utf-8') as f: - prompt = f.read() -elif os.getenv("MODEL_TESTING_PROMPT"): - prompt = os.getenv("MODEL_TESTING_PROMPT") -else: - prompt = "Hello, my name is" -input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + return model, tokenizer, config -print(f"Input tokens: {input_ids}") -print(f"Input text: {repr(prompt)}") -print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}") +def enable_torch_debugging(model): + for name, module in model.named_modules(): + if len(list(module.children())) == 0: # only leaf modules + module.register_forward_hook(debug_hook(name)) -batch_size = 512 +def get_prompt(args): + if args.prompt_file: + with open(args.prompt_file, encoding='utf-8') as f: + return f.read() + elif os.getenv("MODEL_TESTING_PROMPT"): + return os.getenv("MODEL_TESTING_PROMPT") + else: + return "Hello, my name is" -with torch.no_grad(): - past = None - outputs = None - for i in range(0, input_ids.size(1), batch_size): - print(f"Processing chunk with tokens {i} to {i + batch_size}") - chunk = input_ids[:, i:i + batch_size] - outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True) - past = outputs.past_key_values +def main(): + args = parse_arguments() + model_path = os.environ.get("MODEL_PATH", args.model_path) + if model_path is None: + print("Error: Model path must be specified either via --model-path argument or MODEL_PATH environment variable") + sys.exit(1) - logits = outputs.logits # type: ignore - # Extract logits for the last token (next token prediction) - last_logits = logits[0, -1, :].float().cpu().numpy() + model, tokenizer, config = load_model_and_tokenizer(model_path, args.device) - print(f"Logits shape: {logits.shape}") - print(f"Last token logits shape: {last_logits.shape}") - print(f"Vocab size: {len(last_logits)}") + if args.verbose: + enable_torch_debugging(model) - data_dir = Path("data") - data_dir.mkdir(exist_ok=True) - bin_filename = data_dir / f"pytorch-{model_name}.bin" - txt_filename = data_dir / f"pytorch-{model_name}.txt" + model_name = os.path.basename(model_path) - # Save to file for comparison - last_logits.astype(np.float32).tofile(bin_filename) + # Iterate over the model parameters (the tensors) and get the first one + # and use it to get the device the model is on. 
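`enable_torch_debugging()` above relies on PyTorch forward hooks; a condensed sketch of the same idea, with the hypothetical `attach_debug_hooks` and a caller-supplied `hook_factory`:

```python
def attach_debug_hooks(model, hook_factory):
    # register a forward hook on every leaf module (one with no children),
    # so each activation can be inspected as it is produced
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:
            module.register_forward_hook(hook_factory(name))

# e.g. print every leaf output shape:
# attach_debug_hooks(model, lambda name: lambda mod, inp, out:
#                    print(name, getattr(out, "shape", type(out))))
```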
+ device = next(model.parameters()).device + prompt = get_prompt(args) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) - # Also save as text file for easy inspection - with open(txt_filename, "w") as f: - for i, logit in enumerate(last_logits): - f.write(f"{i}: {logit:.6f}\n") + print(f"Input tokens: {input_ids}") + print(f"Input text: {repr(prompt)}") + print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}") - # Print some sample logits for quick verification - print(f"First 10 logits: {last_logits[:10]}") - print(f"Last 10 logits: {last_logits[-10:]}") + batch_size = 512 - # Show top 5 predicted tokens - top_indices = np.argsort(last_logits)[-5:][::-1] - print("Top 5 predictions:") - for idx in top_indices: - token = tokenizer.decode([idx]) - print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}") + with torch.no_grad(): + past = None + outputs = None + for i in range(0, input_ids.size(1), batch_size): + print(f"Processing chunk with tokens {i} to {i + batch_size}") + chunk = input_ids[:, i:i + batch_size] + outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True) + past = outputs.past_key_values - print(f"Saved bin logits to: {bin_filename}") - print(f"Saved txt logist to: {txt_filename}") + logits = outputs.logits # type: ignore + + # Extract logits for the last token (next token prediction) + last_logits = logits[0, -1, :].float().cpu().numpy() + + print(f"Logits shape: {logits.shape}") + print(f"Last token logits shape: {last_logits.shape}") + print(f"Vocab size: {len(last_logits)}") + + data_dir = Path("data") + data_dir.mkdir(exist_ok=True) + bin_filename = data_dir / f"pytorch-{model_name}.bin" + txt_filename = data_dir / f"pytorch-{model_name}.txt" + + # Save to file for comparison + last_logits.astype(np.float32).tofile(bin_filename) + + # Also save as text file for easy inspection + with open(txt_filename, "w") as f: + for i, logit in enumerate(last_logits): + f.write(f"{i}: {logit:.6f}\n") + + # Print some sample logits for quick verification + print(f"First 10 logits: {last_logits[:10]}") + print(f"Last 10 logits: {last_logits[-10:]}") + + # Show top 5 predicted tokens + top_indices = np.argsort(last_logits)[-5:][::-1] + print("Top 5 predictions:") + for idx in top_indices: + token = tokenizer.decode([idx]) + print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}") + + print(f"Saved bin logits to: {bin_filename}") + print(f"Saved txt logist to: {txt_filename}") + +if __name__ == "__main__": + main() diff --git a/examples/model-conversion/scripts/embedding/run-original-model.py b/examples/model-conversion/scripts/embedding/run-original-model.py index 640e200a97..39f054d0e0 100755 --- a/examples/model-conversion/scripts/embedding/run-original-model.py +++ b/examples/model-conversion/scripts/embedding/run-original-model.py @@ -45,7 +45,7 @@ if use_sentence_transformers: else: tokenizer = AutoTokenizer.from_pretrained(model_path) - config = AutoConfig.from_pretrained(model_path) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) # This can be used to override the sliding window size for manual testing. 
This # can be useful to verify the sliding window attention mask in the original model @@ -64,12 +64,12 @@ else: try: model_class = getattr(importlib.import_module(unreleased_module_path), class_name) - model = model_class.from_pretrained(model_path, config=config) + model = model_class.from_pretrained(model_path, config=config, trust_remote_code=True) except (ImportError, AttributeError) as e: print(f"Failed to import or load model: {e}") exit(1) else: - model = AutoModel.from_pretrained(model_path, config=config) + model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True) print(f"Model class: {type(model)}") print(f"Model file: {type(model).__module__}") @@ -123,7 +123,7 @@ with torch.no_grad(): outputs = model(**encoded) hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size] - all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size] + all_embeddings = hidden_states[0].float().cpu().numpy() # Shape: [seq_len, hidden_size] print(f"Hidden states shape: {hidden_states.shape}") print(f"All embeddings shape: {all_embeddings.shape}") diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py index 2ac8b6b7b4..e64c000497 100644 --- a/examples/model-conversion/scripts/utils/semantic_check.py +++ b/examples/model-conversion/scripts/utils/semantic_check.py @@ -166,7 +166,7 @@ def main(): # Load the python model to get configuration information and also to load the tokenizer. print("Loading model and tokenizer using AutoTokenizer:", args.model_path) tokenizer = AutoTokenizer.from_pretrained(args.model_path) - config = AutoConfig.from_pretrained(args.model_path) + config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) if unreleased_model_name: model_name_lower = unreleased_model_name.lower() @@ -186,9 +186,9 @@ def main(): exit(1) else: if args.causal: - model = AutoModelForCausalLM.from_pretrained(args.model_path) + model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True) else: - model = AutoModel.from_pretrained(args.model_path) + model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True) encoded = tokenizer(prompt, return_tensors="pt") tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0]) diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index a018e45197..cf23619ee0 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -22,9 +22,9 @@ if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "use $GGML_SYCL_DEVICE as main GPU" #use signle GPU only - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} fi diff --git a/examples/sycl/run-llama3.sh b/examples/sycl/run-llama3.sh index 4770255703..feee5165e9 100755 --- a/examples/sycl/run-llama3.sh +++ b/examples/sycl/run-llama3.sh @@ -24,8 +24,8 @@ export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 if [ $# -gt 0 ]; then 
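Returning to the embedding-script change above: the `.float()` inserted before `.cpu().numpy()` is needed because NumPy has no bfloat16 dtype, so bf16 hidden states must be upcast first. A minimal repro:

```python
import torch

x = torch.randn(2, 3, dtype=torch.bfloat16)
# x.cpu().numpy() raises TypeError (NumPy does not support bfloat16)
emb = x.float().cpu().numpy()  # upcast to float32 first: shape (2, 3), dtype float32
```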
GGML_SYCL_DEVICE=$1 echo "Using $GGML_SYCL_DEVICE as the main GPU" - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} fi diff --git a/examples/sycl/win-run-llama2.bat b/examples/sycl/win-run-llama2.bat index b654f88f62..32ff673ae2 100644 --- a/examples/sycl/win-run-llama2.bat +++ b/examples/sycl/win-run-llama2.bat @@ -8,4 +8,4 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" :: support malloc device memory more than 4GB. set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 -.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0 +.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 diff --git a/examples/sycl/win-run-llama3.bat b/examples/sycl/win-run-llama3.bat index 608b834f60..ea4ae69d6c 100644 --- a/examples/sycl/win-run-llama3.bat +++ b/examples/sycl/win-run-llama3.bat @@ -8,4 +8,4 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" :: support malloc device memory more than 4GB. set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 -.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99 +.\build\bin\llama-completion.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -no-cnv -p %INPUT2% -n 400 -s 0 -e -ngl 99 diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 835b53f659..dff72a277a 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2338,19 +2338,19 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. // TODO: acl_yarn_ramp_tensor use rope cache. 
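For reference, the ramp that the CANN hunk below builds on-device is the one quoted in its comment, a clamped linear ramp shifted by -1 (the kernel stores the negated ramp, hence `-rope_yarn_ramp`). As scalar Python:

```python
def neg_yarn_ramp(i0: float, low: float, high: float) -> float:
    # y = (i0/2 - low) / max(0.001, high - low), clamped to [0, 1], minus 1
    y = (i0 / 2 - low) / max(0.001, high - low)
    return min(1.0, max(0.0, y)) - 1.0
```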
bool yarn_ramp_tensor_updated = false; - ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); acl_tensor_ptr acl_yarn_ramp_tensor; if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { yarn_ramp_tensor_updated = true; - + if (ctx.rope_cache.yarn_ramp_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache)); + } + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); // -rope_yarn_ramp // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); // return MIN(1, MAX(0, y)) - 1; - yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); - void * yarn_ramp_buffer = yarn_ramp_allocator.get(); acl_yarn_ramp_tensor = - ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); @@ -2380,8 +2380,10 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); + } else { + acl_yarn_ramp_tensor = + ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); } - // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. 
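The change above replaces a per-call pool allocation with a buffer owned by `ggml_cann_rope_cache` that is freed and re-created only when the parameters that shaped it change. The pattern, sketched in Python with hypothetical names:

```python
class ParamKeyedCache:
    """Rebuild an expensive buffer only when its defining parameters change."""
    def __init__(self):
        self.buf = None
        self.key = None

    def get(self, length, freq_scale, build):
        key = (length, freq_scale)
        if self.buf is None or key != self.key:
            self.buf = build(length, freq_scale)  # replaces the stale buffer
            self.key = key
        return self.buf
```

Freeing the old buffer before allocating the new one, as the `aclrtFree`/`aclrtMalloc` pair above does, keeps peak memory at one buffer instead of two.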
if (ext_factor != 0) { if (theta_scale_updated || yarn_ramp_tensor_updated) { diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 45c7294e68..3a461ef1a7 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -318,6 +318,9 @@ struct ggml_cann_rope_cache { if (position_select_index_host) { free(position_select_index_host); } + if (yarn_ramp_cache) { + ACL_CHECK(aclrtFree(yarn_ramp_cache)); + } } bool equal(int64_t theta_scale_length, @@ -370,6 +373,7 @@ struct ggml_cann_rope_cache { float * theta_scale_exp_host = nullptr; int * position_select_index_host = nullptr; void * position_select_index = nullptr; + void * yarn_ramp_cache = nullptr; // sin/cos cache, used only to accelerate first layer on each device void * sin_cache = nullptr; void * cos_cache = nullptr; diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index a0cce10aa7..7dc36d4f8a 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -69,6 +69,10 @@ #define VECTOR_REGISTERS 16 #endif +#if defined(__riscv_v_intrinsic) +#define LMUL 4 +#endif + #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) namespace { @@ -175,6 +179,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { } #endif +#if defined(__riscv_zvfh) +template <> +inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { + return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { + return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { + return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { + return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} +inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { + return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { + return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { + return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { + return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} +#endif + +#if defined(__riscv_zvfbfwma) +inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { + return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { + return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { + return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM @@ -227,6 +271,25 @@ inline float hsum(__m512 x) { } #endif // __AVX512F__ +#if defined(__riscv_zvfh) +inline float hsum(vfloat32m1_t x) { + return __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1())); +} +inline float hsum(vfloat32m2_t x) { + 
return __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2())); +} +inline float hsum(vfloat32m4_t x) { + return __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4())); +} +inline float hsum(vfloat32m8_t x) { + return __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8())); +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED MEMORY LOADING @@ -315,6 +378,88 @@ template <> inline __m256bh load(const float *p) { } #endif +#if defined(__riscv_zvfh) +template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) { + return __riscv_vle16_v_f16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2()); +} +template <> inline vfloat16m1_t load(const ggml_fp16_t *p) { + return __riscv_vle16_v_f16m1(reinterpret_cast(p), __riscv_vsetvlmax_e16m1()); +} +template <> inline vfloat16m2_t load(const ggml_fp16_t *p) { + return __riscv_vle16_v_f16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2()); +} +template <> inline vfloat16m4_t load(const ggml_fp16_t *p) { + return __riscv_vle16_v_f16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4()); +} +template <> inline vfloat32m1_t load(const float *p) { + return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t load(const float *p) { + return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t load(const float *p) { + return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t load(const float *p) { + return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); +} +#endif + +#if defined(__riscv_zvfbfwma) +template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2()); +} +template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16m1(reinterpret_cast(p), __riscv_vsetvlmax_e16m1()); +} +template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2()); +} +#endif + +#if defined(__riscv_zvfh) +template T set_zero(); + +template <> inline vfloat16mf2_t set_zero() { + return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2()); +} +template <> inline vfloat16m1_t set_zero() { + return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1()); +} +template <> inline vfloat16m2_t set_zero() { + return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2()); +} +template <> inline vfloat16m4_t set_zero() { + return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4()); +} +template <> inline vfloat32m1_t set_zero() { + return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t set_zero() { + return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t set_zero() { + return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t set_zero() { + return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8()); +} +#endif + +#if defined(__riscv_v_intrinsic) +template size_t vlmax() { + if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } + else if 
constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m4(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m8(); } + return 0; +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION @@ -488,6 +633,573 @@ class tinyBLAS { const int64_t ldc; }; +#if defined(__riscv_v_intrinsic) +template +class tinyBLAS_RVV { + public: + tinyBLAS_RVV(const ggml_compute_params * params, int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc) + : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) { + } + + bool matmul(int64_t m, int64_t n) { + if (k % vlmax() != 0) { + return false; + } + +#if LMUL == 1 + if (m % 16 == 0 && (m/16 >= params->nth)) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 4>(m, n, SIZE_N, 12); + return true; + } + if (m % 8 == 0 ) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 2>(m, n, SIZE_N, 12); + return true; + } + if (m % 4 == 0) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 1>(m, n, SIZE_N, 12); + return true; + } +#elif LMUL == 2 + if (m % 16 == 0 && (m/16 >= params->nth)) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 4>(m, n, SIZE_N, 24); + return true; + } + if (m % 8 == 0 ) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 2>(m, n, SIZE_N, 24); + return true; + } + if (m % 4 == 0) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 1>(m, n, SIZE_N, 24); + return true; + } +#else // LMUL = 4 + if (m % 16 == 0 && (m/16 >= params->nth)) { + const int64_t SIZE_N = BLOCK_SIZE<2>(n); + mnpack<2, 2, 8>(m, n, SIZE_N, 36); + return true; + } + if (m % 8 == 0 ) { + const int64_t SIZE_N = BLOCK_SIZE<2>(n); + mnpack<2, 2, 4>(m, n, SIZE_N, 36); + return true; + } + if (m % 4 == 0) { + const int64_t SIZE_N = BLOCK_SIZE<2>(n); + mnpack<2, 2, 2>(m, n, SIZE_N, 36); + return true; + } +#endif + return false; + } + + private: + template + inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) { + if (SIZE_N == RN) { + return gemm(m, n, BN); + } + if constexpr (RN > 1) { + return mnpack(m, n, SIZE_N, BN); + } else { + GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_ASSERT(false); // we have miss something. 
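The `mnpack` recursion above resolves the column-block width at compile time; each `RN` step is a separate template instantiation. The same descent as plain runtime Python, for intuition only:

```python
def mnpack(size_n: int, rn: int):
    # peel RN down until it matches the block size computed for this matrix;
    # reaching rn == 1 without a match mirrors the GGML_ASSERT(false) above
    if size_n == rn:
        return rn  # corresponds to calling gemm<RM, RN, BM>(...)
    if rn > 1:
        return mnpack(size_n, rn - 1)
    raise AssertionError(f"block size {size_n} not supported")
```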
+ } + } + + inline void gemm_bloc_4x6(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + D Cv20 = set_zero(); + D Cv21 = set_zero(); + D Cv22 = set_zero(); + D Cv23 = set_zero(); + D Cv30 = set_zero(); + D Cv31 = set_zero(); + D Cv32 = set_zero(); + D Cv33 = set_zero(); + D Cv40 = set_zero(); + D Cv41 = set_zero(); + D Cv42 = set_zero(); + D Cv43 = set_zero(); + D Cv50 = set_zero(); + D Cv51 = set_zero(); + D Cv52 = set_zero(); + D Cv53 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Bv0 = load(B + ldb * (jj + 0) + l); + V Bv1 = load(B + ldb * (jj + 1) + l); + V Bv2 = load(B + ldb * (jj + 2) + l); + V Bv3 = load(B + ldb * (jj + 3) + l); + V Bv4 = load(B + ldb * (jj + 4) + l); + V Bv5 = load(B + ldb * (jj + 5) + l); + + V Av0 = load(A + lda * (ii + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv10 = madd(Av0, Bv1, Cv10); + Cv20 = madd(Av0, Bv2, Cv20); + Cv30 = madd(Av0, Bv3, Cv30); + Cv40 = madd(Av0, Bv4, Cv40); + Cv50 = madd(Av0, Bv5, Cv50); + + V Av1 = load(A + lda * (ii + 1) + l); + Cv01 = madd(Av1, Bv0, Cv01); + Cv11 = madd(Av1, Bv1, Cv11); + Cv21 = madd(Av1, Bv2, Cv21); + Cv31 = madd(Av1, Bv3, Cv31); + Cv41 = madd(Av1, Bv4, Cv41); + Cv51 = madd(Av1, Bv5, Cv51); + + V Av2 = load(A + lda * (ii + 2) + l); + Cv02 = madd(Av2, Bv0, Cv02); + Cv12 = madd(Av2, Bv1, Cv12); + Cv22 = madd(Av2, Bv2, Cv22); + Cv32 = madd(Av2, Bv3, Cv32); + Cv42 = madd(Av2, Bv4, Cv42); + Cv52 = madd(Av2, Bv5, Cv52); + + V Av3 = load(A + lda * (ii + 3) + l); + Cv03 = madd(Av3, Bv0, Cv03); + Cv13 = madd(Av3, Bv1, Cv13); + Cv23 = madd(Av3, Bv2, Cv23); + Cv33 = madd(Av3, Bv3, Cv33); + Cv43 = madd(Av3, Bv4, Cv43); + Cv53 = madd(Av3, Bv5, Cv53); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20); + C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21); + C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22); + C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23); + C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30); + C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31); + C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32); + C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33); + C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40); + C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41); + C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42); + C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43); + C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50); + C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51); + C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52); + C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53); + } + + inline void gemm_bloc_4x5(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + D Cv20 = set_zero(); + D Cv21 = set_zero(); + D Cv22 = set_zero(); + D Cv23 = set_zero(); + D Cv30 = set_zero(); + D Cv31 = set_zero(); + D Cv32 = set_zero(); + D Cv33 = set_zero(); + D Cv40 = set_zero(); + D Cv41 = set_zero(); + D Cv42 = set_zero(); + D Cv43 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Bv0 = load(B + ldb * (jj + 0) + l); + V Bv1 = 
load(B + ldb * (jj + 1) + l); + V Bv2 = load(B + ldb * (jj + 2) + l); + V Bv3 = load(B + ldb * (jj + 3) + l); + V Bv4 = load(B + ldb * (jj + 4) + l); + + V Av0 = load(A + lda * (ii + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv10 = madd(Av0, Bv1, Cv10); + Cv20 = madd(Av0, Bv2, Cv20); + Cv30 = madd(Av0, Bv3, Cv30); + Cv40 = madd(Av0, Bv4, Cv40); + + V Av1 = load(A + lda * (ii + 1) + l); + Cv01 = madd(Av1, Bv0, Cv01); + Cv11 = madd(Av1, Bv1, Cv11); + Cv21 = madd(Av1, Bv2, Cv21); + Cv31 = madd(Av1, Bv3, Cv31); + Cv41 = madd(Av1, Bv4, Cv41); + + V Av2 = load(A + lda * (ii + 2) + l); + Cv02 = madd(Av2, Bv0, Cv02); + Cv12 = madd(Av2, Bv1, Cv12); + Cv22 = madd(Av2, Bv2, Cv22); + Cv32 = madd(Av2, Bv3, Cv32); + Cv42 = madd(Av2, Bv4, Cv42); + + V Av3 = load(A + lda * (ii + 3) + l); + Cv03 = madd(Av3, Bv0, Cv03); + Cv13 = madd(Av3, Bv1, Cv13); + Cv23 = madd(Av3, Bv2, Cv23); + Cv33 = madd(Av3, Bv3, Cv33); + Cv43 = madd(Av3, Bv4, Cv43); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20); + C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21); + C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22); + C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23); + C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30); + C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31); + C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32); + C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33); + C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40); + C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41); + C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42); + C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43); + } + + inline void gemm_bloc_4x4(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + D Cv20 = set_zero(); + D Cv21 = set_zero(); + D Cv22 = set_zero(); + D Cv23 = set_zero(); + D Cv30 = set_zero(); + D Cv31 = set_zero(); + D Cv32 = set_zero(); + D Cv33 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + V Av2 = load(A + lda * (ii + 2) + l); + V Av3 = load(A + lda * (ii + 3) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + Cv02 = madd(Av2, Bv0, Cv02); + Cv03 = madd(Av3, Bv0, Cv03); + + V Bv1 = load(B + ldb * (jj + 1) + l); + Cv10 = madd(Av0, Bv1, Cv10); + Cv11 = madd(Av1, Bv1, Cv11); + Cv12 = madd(Av2, Bv1, Cv12); + Cv13 = madd(Av3, Bv1, Cv13); + + V Bv2 = load(B + ldb * (jj + 2) + l); + Cv20 = madd(Av0, Bv2, Cv20); + Cv21 = madd(Av1, Bv2, Cv21); + Cv22 = madd(Av2, Bv2, Cv22); + Cv23 = madd(Av3, Bv2, Cv23); + + V Bv3 = load(B + ldb * (jj + 3) + l); + Cv30 = madd(Av0, Bv3, Cv30); + Cv31 = madd(Av1, Bv3, Cv31); + Cv32 = madd(Av2, Bv3, Cv32); + Cv33 = madd(Av3, Bv3, Cv33); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + C[ldc * (jj + 2) + (ii + 
0)] = hsum(Cv20); + C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21); + C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22); + C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23); + C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30); + C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31); + C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32); + C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33); + } + + inline void gemm_bloc_4x3(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + D Cv20 = set_zero(); + D Cv21 = set_zero(); + D Cv22 = set_zero(); + D Cv23 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + V Av2 = load(A + lda * (ii + 2) + l); + V Av3 = load(A + lda * (ii + 3) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + Cv02 = madd(Av2, Bv0, Cv02); + Cv03 = madd(Av3, Bv0, Cv03); + + V Bv1 = load(B + ldb * (jj + 1) + l); + Cv10 = madd(Av0, Bv1, Cv10); + Cv11 = madd(Av1, Bv1, Cv11); + Cv12 = madd(Av2, Bv1, Cv12); + Cv13 = madd(Av3, Bv1, Cv13); + + V Bv2 = load(B + ldb * (jj + 2) + l); + Cv20 = madd(Av0, Bv2, Cv20); + Cv21 = madd(Av1, Bv2, Cv21); + Cv22 = madd(Av2, Bv2, Cv22); + Cv23 = madd(Av3, Bv2, Cv23); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20); + C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21); + C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22); + C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23); + } + + inline void gemm_bloc_4x2(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + V Av2 = load(A + lda * (ii + 2) + l); + V Av3 = load(A + lda * (ii + 3) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + Cv02 = madd(Av2, Bv0, Cv02); + Cv03 = madd(Av3, Bv0, Cv03); + + V Bv1 = load(B + ldb * (jj + 1) + l); + Cv10 = madd(Av0, Bv1, Cv10); + Cv11 = madd(Av1, Bv1, Cv11); + Cv12 = madd(Av2, Bv1, Cv12); + Cv13 = madd(Av3, Bv1, Cv13); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + } + + inline void gemm_bloc_4x1(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + V Av2 = load(A + lda * (ii + 2) + l); + V Av3 = load(A + lda * (ii + 3) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); 
+ Cv01 = madd(Av1, Bv0, Cv01); + Cv02 = madd(Av2, Bv0, Cv02); + Cv03 = madd(Av3, Bv0, Cv03); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + } + + inline void gemm_bloc_2x2(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + + V Bv1 = load(B + ldb * (jj + 1) + l); + Cv10 = madd(Av0, Bv1, Cv10); + Cv11 = madd(Av1, Bv1, Cv11); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + } + + inline void gemm_bloc_2x1(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + } + + template + inline void gemm_bloc(int64_t ii, int64_t jj) { + if constexpr (RM == 4) { + if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); } + if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); } + if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); } + if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); } + if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); } + if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); } + } else if constexpr (RM == 2) { + if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); } + if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); } + } + } + + template + NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) { + GGML_ASSERT(m % (RM * BM) == 0); + const int64_t ytiles = m / (RM * BM); + const int64_t xtiles = (n + RN -1) / RN; + const int64_t jj_RN = (xtiles - (xtiles * RN - n)); + + // "round" bloc_size to "nearest" BN + const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN; + const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1; + const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles)); + const int64_t nb_job = ytiles * NB_BN; + + if (params->ith == 0) { + GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles); + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_chunk_set(params->threadpool, params->nth); + } + + ggml_barrier(params->threadpool); + + int64_t job = params->ith; + while (job < nb_job) { + const int64_t ii = (job % ytiles) * RM * BM; + const int64_t jb = job / ytiles; + const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN); + const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN); + + const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN); + const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN); + const int64_t jj1 = jj2 < jj_RN * RN ? 
jj2 : jj_RN * RN; + + for (int64_t bi = 0; bi < BM * RM; bi += RM) { + int64_t jj = jj0; + for (; jj < jj1; jj += RN) { + gemm_bloc(ii + bi, jj); + } + if constexpr (RN > 1) { + for (; jj < jj2; jj += RN - 1) { + gemm_bloc(ii + bi, jj); + } + } + GGML_ASSERT(jj == jj2); + } + + job = ggml_threadpool_chunk_add(params->threadpool, 1); + } + + ggml_barrier(params->threadpool); + return; + } + + const ggml_compute_params * params; + const TA *const A; + const TB *const B; + TC *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; +}; +#endif + ////////////////////////////////////////////////////////////////////////////////////////// // QUANT ZERO MATRIX MULTIPLICATION @@ -2657,6 +3369,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; +#elif defined(__riscv_zvfh) + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); #else return false; #endif @@ -2699,6 +3429,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 tb.matmul(m, n); return true; } +#elif defined(__riscv_zvfbfwma) + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); #endif return false; } @@ -2748,6 +3496,26 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 (float *)C, ldc}; return tb.matmul(m, n); } +#elif defined(__riscv_zvfh) + if (Btype == GGML_TYPE_F16) { + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); + } #endif return false; } diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 6a00abacc3..13b96d61f8 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -7,9 +7,10 @@ #include #include +#include #include -#include #include +#include #ifdef _WIN32 # include @@ -36,6 +37,7 @@ #include "ggml-hexagon.h" #include "ggml-impl.h" #include "ggml-quants.h" +#include "op-desc.h" #include "htp-msg.h" #include "htp_iface.h" @@ -55,9 +57,6 @@ static int opt_opsync = 0; // synchronous ops #define HEX_VERBOSE(...) \ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) -#define HEX_PROFILE(...) 
\ - if (opt_profile) GGML_LOG_INFO(__VA_ARGS__) - static inline uint64_t hex_is_aligned(void * addr, uint32_t align) { return ((size_t) addr & (align - 1)) == 0; } @@ -85,128 +84,30 @@ static const char * status_to_str(uint32_t status) { // ** debug helpers -static inline int hex_format_tensor_dims(char * str, const struct ggml_tensor * t) { - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); - } else { - return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); - } +static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) { + if (!opt_verbose) return; + + op_desc desc(op); + GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), + ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); } -static inline void hex_format_op_dims(char * str, const struct ggml_tensor * t) { - char * p = str; +static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { + if (!opt_verbose) return; - // append src0 and src1 (if any) - if (t->src[0]) { - p += hex_format_tensor_dims(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += hex_format_tensor_dims(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - hex_format_tensor_dims(self, t); - - p += sprintf(p, "%s", self); + op_desc desc(op); + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); } -static inline int hex_format_tensor_strides(char * str, const struct ggml_tensor * t) { - const char * c = ggml_is_contiguous(t) ? 
"" : "!"; +static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, + uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) { + if (!opt_profile) return; - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); - } else { - return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], - (size_t) t->nb[3], c); - } -} - -static inline void hex_format_op_strides(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += hex_format_tensor_strides(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += hex_format_tensor_strides(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - hex_format_tensor_strides(self, t); - - p += sprintf(p, "%s", self); -} - -static inline void hex_format_op_types(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", ggml_type_name(t->src[0]->type)); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(t->src[i]->type)); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", ggml_type_name(t->type)); -} - -static inline const char * hex_tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { - return ggml_backend_buffer_name(t->buffer); - } - return "NONE"; -} - -static inline void hex_format_op_buffs(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", hex_tensor_buff_name(t->src[0])); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", hex_tensor_buff_name(t->src[i])); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", hex_tensor_buff_name(t)); -} - -static inline void hex_format_op_names(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", t->src[0]->name); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", t->src[i]->name); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", t->name); + op_desc desc(op); + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(), + ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, + op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec); } // ** backend sessions @@ -221,8 +122,8 @@ struct ggml_hexagon_session { void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false); void flush(); - ggml_backend_buffer_type buffer_type; - ggml_backend_buffer_type repack_buffer_type; + ggml_backend_buffer_type buffer_type = {}; + ggml_backend_buffer_type repack_buffer_type = {}; std::string name; remote_handle64 handle; @@ -241,23 +142,6 @@ struct ggml_hexagon_session { uint32_t prof_pkts; }; -static inline void hex_print_op_info(const ggml_tensor * op, ggml_hexagon_session * sess, const uint32_t req_flags) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char 
buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), - names, dims, types, strides, buffs, req_flags); -} - void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { // Bump pending flag (cleared in the session::flush once we get the responce) this->op_pending++; // atomic inc @@ -1598,7 +1482,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( try { ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/); return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); - } catch (std::exception const &exc) { + } catch (const std::exception & exc) { GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); return nullptr; } @@ -1610,7 +1494,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe try { ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/); return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); - } catch (std::exception const &exc) { + } catch (const std::exception & exc) { GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); return nullptr; } @@ -1697,8 +1581,8 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } // Save the IDs - this->session_id = n.session_id; - this->domain_id = n.effective_domain_id; + this->session_id = n.session_id; + this->domain_id = n.effective_domain_id; this->valid_session = true; } @@ -1751,7 +1635,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_handle = true; GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(), - this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); + this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); // Enable FastRPC QoS mode { @@ -1838,11 +1722,8 @@ void ggml_hexagon_session::release() noexcept(true) { } ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) { - buffer_type.context = nullptr; - repack_buffer_type.context = nullptr; - - buffer_type.device = dev; - repack_buffer_type.device = dev; + buffer_type.device = dev; + repack_buffer_type.device = dev; try { allocate(dev_id); @@ -1852,7 +1733,7 @@ ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) n repack_buffer_type.iface = ggml_backend_hexagon_repack_buffer_type_interface; repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this); - } catch (std::exception const &exc) { + } catch (const std::exception & exc) { release(); throw; } @@ -1861,8 +1742,8 @@ ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) n ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) { release(); - delete static_cast(buffer_type.context); - delete static_cast(repack_buffer_type.context); + delete static_cast(buffer_type.context); + delete static_cast(repack_buffer_type.context); } // ** 
backend interface @@ -1930,15 +1811,6 @@ static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_t return true; } -template -static inline bool hex_supported_buffer(const struct ggml_hexagon_session * sess, _TTensor... tensors) { - return ([&]() -> bool { - return !tensors || !tensors->buffer || - (ggml_backend_buffer_is_hexagon(tensors->buffer) && - ggml_backend_hexagon_buffer_get_sess(tensors->buffer) == sess); - }() && ...); -} - static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -1976,17 +1848,16 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s break; case GGML_TYPE_F16: + if (src0->nb[1] < src0->nb[0]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); + return false; + } break; default: return false; } - // src0 & src1 & dst must be mapped to the same session - if (!hex_supported_buffer(sess, src0, src1, dst)) { - return false; - } - return true; } @@ -2029,12 +1900,6 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session return false; } - // src0 (weights) must be repacked and mapped to the same session - // src1 & sr2 & dst must be mapped to the same session - if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { - return false; - } - return true; } @@ -2064,18 +1929,12 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se return false; } - // src0, src1 & dst must be mapped to the same session - if (!hex_supported_buffer(sess, src0, src1, dst)) { - return false; - } - return true; } static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * src1 = op->src[1]; - const struct ggml_tensor * src2 = op->src[2]; const struct ggml_tensor * dst = op; if (!hex_supported_src0_type(src0->type)) { @@ -2096,11 +1955,6 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se return false; } - // src0, src1 & dst must be mapped to the same session - if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { - return false; - } - return true; } @@ -2123,11 +1977,6 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return false; } - // src0 & dst must be mapped to the same session - if (!hex_supported_buffer(sess, src0, dst)) { - return false; - } - return true; } @@ -2160,17 +2009,6 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session } } - // src0, src1 & dst must be mapped to the same session - if(src1){ - if (!hex_supported_buffer(sess, src0, src1, dst)) { - return false; - } - }else{ - if (!hex_supported_buffer(sess, src0, dst)) { - return false; - } - } - return true; } @@ -2219,11 +2057,6 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s } } - // src0, src1 & dst must be mapped to the same session - if (!hex_supported_buffer(sess, src0, src1, dst)) { - return false; - } - return true; } @@ -2274,16 +2107,28 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess } } - // src0, src1, src2 & dst must be mapped to the same session - if (!hex_supported_buffer(sess, src0, src1, src2, dst)) { - return false; - } - return true; } +enum dspqbuf_type { + DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, + 
DSPQBUF_TYPE_CPU_WRITE_DSP_READ, + DSPQBUF_TYPE_CONSTANT, +}; + +static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) { + if (opt_verbose < 2) return; + + auto buf = static_cast(t->buffer->context); + auto sess = buf->sess; + + GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(), + t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset, + (unsigned int) d->size); +} + // Init hexagon tensor from GGML tensor and Hexagon buffer -static void init_htp_tensor(htp_tensor * h, const ggml_tensor * t) { +static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) { h->data = 0; // updated by the receiver h->type = t->type; h->ne[0] = t->ne[0]; @@ -2296,53 +2141,52 @@ static void init_htp_tensor(htp_tensor * h, const ggml_tensor * t) { h->nb[3] = t->nb[3]; } -static size_t dspqueue_buffers_init(dspqueue_buffer * buf, const ggml_tensor * t, bool flush_host, bool flush_htp) { +static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) { if (!t) { return 0; } - memset(buf, 0, sizeof(*buf)); - auto tensor_buf = static_cast(t->buffer->context); - buf->fd = tensor_buf->fd; - buf->ptr = t->data; - buf->offset = (uint8_t *) t->data - tensor_buf->base; - buf->size = ggml_nbytes(t); - buf->flags = (flush_host ? DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER : 0); // Flush CPU - buf->flags |= (flush_htp ? DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT : 0); // Invalidate DSP + auto buf = static_cast(t->buffer->context); + + memset(d, 0, sizeof(*d)); + d->fd = buf->fd; + d->ptr = t->data; + d->offset = (uint8_t *) t->data - buf->base; + d->size = ggml_nbytes(t); + + switch (type) { + case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: + // Flush CPU + d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER; + break; + case DSPQBUF_TYPE_CPU_WRITE_DSP_READ: + // Flush CPU, Invalidate DSP + d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + break; + default: + // Constant buffer, no cache maintenance + d->flags = 0; + break; + } + + htp_req_tensor_init(h, t); + + dspqbuf_dump(d, t, type); + return 1; } -static ggml_hexagon_session * get_session_from_tensor(const ggml_tensor * t) { - return static_cast(t->buffer->context)->sess; -} +typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op); -static void hex_dump_dspbuf(const struct ggml_tensor * t, const dspqueue_buffer * d) { - auto buf = static_cast(t->buffer->context); - auto sess = buf->sess; +template +static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) { + uint64_t t = ggml_time_us(); - HEX_VERBOSE("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(), - t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset, - (unsigned int) d->size); -} - -static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) { - const struct ggml_tensor * src0 = op->src[0]; - const struct ggml_tensor * src1 = op->src[1]; - const struct ggml_tensor * dst = op; - - uint64_t t1, t2; - t1 = ggml_time_us(); - - // Construct HTP message + // Construct HTP request htp_general_req req; - req.op = HTP_OP_MUL_MAT; + memset(&req, 0, sizeof(req)); + req.flags = flags; - - init_htp_tensor(&req.src0, src0); - init_htp_tensor(&req.src1, src1); - init_htp_tensor(&req.dst, dst); - - // Use opmask to 
override flags if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; } @@ -2350,342 +2194,111 @@ static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - dspqueue_buffer bufs[3]; - - // First buffer Weights. - // The content is static, there is no need to do any cache management - dspqueue_buffers_init(bufs, src0, false, false); - - // Second buffer Input Activations. This is a buffer that the CPU - // writes and the DSP reads, so we'll need to flush CPU caches and - // invalidate DSP ones. On platforms with I/O coherency support the - // framework will automatically skip cache operations where possible. - dspqueue_buffers_init(&bufs[1], src1, true, true); - - // Third buffer Output Activations. We'll handle DSP - // cache maintenance in the response message but need to flush - // CPU caches to ensure any previously written dirty lines are - // written out before writes from the DSP start. - dspqueue_buffers_init(&bufs[2], dst, true, false); - - auto * sess = get_session_from_tensor(src0); - - if (opt_verbose) { - hex_print_op_info(op, sess, req.flags); - if (opt_verbose > 1) { - hex_dump_dspbuf(src0, &bufs[0]); - hex_dump_dspbuf(src1, &bufs[1]); - hex_dump_dspbuf(dst, &bufs[2]); - } - } + ggml_hexagon_dump_op_exec(sess->name, op, req.flags); if ((opt_opmask & HTP_OPMASK_QUEUE)) { - sess->enqueue(req, bufs, 3, opt_opsync); + dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; + size_t n_bufs = _init_req_func(&req, bufs, op); + sess->enqueue(req, bufs, n_bufs, opt_opsync); } - t2 = ggml_time_us(); + t = ggml_time_us() - t; - HEX_PROFILE( - "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) " - "call-usec %llu\n", - sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], - (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], - (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, - (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); + ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t); } -static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flags) { - const struct ggml_tensor * src0 = op->src[0]; - const struct ggml_tensor * src1 = op->src[1]; - const struct ggml_tensor * src2 = op->src[2]; - const struct ggml_tensor * dst = op; - - uint64_t t1, t2; - t1 = ggml_time_us(); - - // Construct HTP message - htp_general_req req; - req.op = HTP_OP_MUL_MAT_ID; - req.flags = flags; - - init_htp_tensor(&req.src0, src0); - init_htp_tensor(&req.src1, src1); - init_htp_tensor(&req.src2, src2); - init_htp_tensor(&req.dst, dst); - - // Use opmask to override flags - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; - } - - dspqueue_buffer bufs[4]; - // First buffer Weights. - // The content is static, there is no need to do any cache management - dspqueue_buffers_init(bufs, src0, false, false); - - // Second buffer Input Activations. This is a buffer that the CPU - // writes and the DSP reads, so we'll need to flush CPU caches and - // invalidate DSP ones. 
On platforms with I/O coherency support the - // framework will automatically skip cache operations where possible. - dspqueue_buffers_init(&bufs[1], src1, true, true); - - // Third buffer expert IDs. This is a buffer that the CPU - // writes and the DSP reads, so we'll need to flush CPU caches and - // invalidate DSP ones. On platforms with I/O coherency support the - // framework will automatically skip cache operations where possible. - dspqueue_buffers_init(&bufs[2], src2, true, true); - - // Forth buffer Output Activations. We'll handle DSP - // cache maintenance in the response message but need to flush - // CPU caches to ensure any previously written dirty lines are - // written out before writes from the DSP start. - dspqueue_buffers_init(&bufs[3], dst, true, false); - - auto * sess = get_session_from_tensor(src0); - - if (opt_verbose) { - hex_print_op_info(op, sess, req.flags); - if (opt_verbose > 1) { - hex_dump_dspbuf(src0, &bufs[0]); - hex_dump_dspbuf(src1, &bufs[1]); - hex_dump_dspbuf(src2, &bufs[2]); - hex_dump_dspbuf(dst, &bufs[3]); - } - } - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - sess->enqueue(req, bufs, 4, opt_opsync); - } - - t2 = ggml_time_us(); - - HEX_PROFILE( - "ggml-hex: %s matmul-id %s %u:%u:%u:%u x %s %u:%u:%u:%u (%s %u:%u:%u:%u) -> %s %u:%u:%u:%u : op-usec %u " - "op-cycles %u op-pkts %u (%f) call-usec %llu\n", - sess->name.c_str(), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2], - (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], (uint32_t) src1->ne[2], - (uint32_t) src1->ne[3], src2->name, (uint32_t) src2->ne[0], (uint32_t) src2->ne[1], (uint32_t) src2->ne[2], - (uint32_t) src2->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], - (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, - (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); -} - -static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { - const struct ggml_tensor * node = op; - const struct ggml_tensor * src0 = node->src[0]; - const struct ggml_tensor * src1 = node->src[1]; - const struct ggml_tensor * dst = node; - - uint64_t t1 = 0; - uint64_t t2 = 0; - - t1 = ggml_time_us(); - - // Construct HTP message - htp_general_req req; - req.flags = flags; - - // Use opmask to override flags - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; - } - - switch (node->op) { +template +static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + switch (t->op) { + case GGML_OP_MUL_MAT: + req->op = HTP_OP_MUL_MAT; + break; case GGML_OP_MUL: - req.op = HTP_OP_MUL; + req->op = HTP_OP_MUL; break; case GGML_OP_ADD: - req.op = HTP_OP_ADD; + req->op = HTP_OP_ADD; break; case GGML_OP_SUB: - req.op = HTP_OP_SUB; + req->op = HTP_OP_SUB; break; default: - GGML_ABORT("ggml-hex: binary : unsupported op:%d\n", node->op); + GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op); + break; } - init_htp_tensor(&req.src0, src0); - init_htp_tensor(&req.src1, src1); - init_htp_tensor(&req.dst, dst); + // src0: Weights (mulmat) or First Operand (binary op). + // If constant (e.g. weights), no cache management is needed. + // src1: Input Activations (mulmat) or Second Operand (binary op). 
- dspqueue_buffer bufs[3]; - // First buffer = First Operand of Binary op - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - dspqueue_buffers_init(bufs, src0, true, true); + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - // Second buffer = Second Operand of Binary op - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - dspqueue_buffers_init(&bufs[1], src1, true, true); - - // Third buffer = Output Activations. We'll handle DSP - // cache maintenance in the response message but need to flush - // CPU caches to ensure any previously written dirty lines are - // written out before writes from the DSP start. - dspqueue_buffers_init(&bufs[2], dst, true, false); - - auto * sess = get_session_from_tensor(src0); - - if (opt_verbose) { - hex_print_op_info(op, sess, req.flags); - if (opt_verbose > 1) { - hex_dump_dspbuf(src0, &bufs[0]); - hex_dump_dspbuf(src1, &bufs[1]); - hex_dump_dspbuf(dst, &bufs[2]); - } - } - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - sess->enqueue(req, bufs, 3, opt_opsync); - } - - t2 = ggml_time_us(); - - HEX_PROFILE( - "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) " - "call-usec %llu\n", - sess->name.c_str(), ggml_op_name(node->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], - (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], - (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, - (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); + return n_bufs; } -static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { - const struct ggml_tensor * node = op; - const struct ggml_tensor * src0 = node->src[0]; - const struct ggml_tensor * src1 = node->src[1]; - const struct ggml_tensor * src2 = node->src[2]; - const struct ggml_tensor * dst = node; - - uint64_t t1 = 0; - uint64_t t2 = 0; - - t1 = ggml_time_us(); - - // Construct HTP message - htp_general_req req; - req.flags = flags; - - // Use opmask to override flags - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; - } - - switch (node->op) { +template +static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + switch (t->op) { + case GGML_OP_MUL_MAT_ID: + req->op = HTP_OP_MUL_MAT_ID; + break; case GGML_OP_ADD_ID: - req.op = HTP_OP_ADD_ID; + req->op = HTP_OP_ADD_ID; break; default: - GGML_ABORT("ggml-hex: unsupported op:%d\n", node->op); + GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op); } - init_htp_tensor(&req.src0, src0); - 
init_htp_tensor(&req.src1, src1); - init_htp_tensor(&req.src2, src2); - init_htp_tensor(&req.dst, dst); + // src0: Weights (mulmat) or Input Activations (other op). + // If constant, no cache management is needed. + // src1: Input Activations (mulmat) or Second Operand (binary op). + // src2: Expert IDs (mulmat) or Activated Experts (other op). - dspqueue_buffer bufs[4]; - // First buffer = input activations - dspqueue_buffers_init(bufs, src0, true, true); - // Second buffer = experts bias - dspqueue_buffers_init(&bufs[1], src1, true, true); - // Third buffer = activated experts - dspqueue_buffers_init(&bufs[2], src2, true, true); - // Forth buffer = output activations - dspqueue_buffers_init(&bufs[3], dst, true, true); + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - auto * sess = get_session_from_tensor(src0); - - if (opt_verbose) { - hex_print_op_info(op, sess, req.flags); - if (opt_verbose > 1) { - hex_dump_dspbuf(src0, &bufs[0]); - hex_dump_dspbuf(src1, &bufs[1]); - hex_dump_dspbuf(src2, &bufs[2]); - hex_dump_dspbuf(dst, &bufs[3]); - } - } - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - sess->enqueue(req, bufs, 4, opt_opsync); - } - - t2 = ggml_time_us(); - - HEX_PROFILE( - "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) " - "call-usec %llu\n", - sess->name.c_str(), ggml_op_name(node->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], - (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], - (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, - (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); + return n_bufs; } -static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { - const struct ggml_tensor * src0 = op->src[0]; - const struct ggml_tensor * src1 = op->src[1]; - const struct ggml_tensor * dst = op; - - uint64_t t1 = 0; - uint64_t t2 = 0; - - t1 = ggml_time_us(); - - // Construct HTP message - htp_general_req req; - - memset(&req, 0, sizeof(htp_general_req)); - memcpy(&req.op_params, &op->op_params, sizeof(op->op_params)); - req.flags = flags; +static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); bool supported = false; - switch (op->op) { + switch (t->op) { case GGML_OP_RMS_NORM: - req.op = HTP_OP_RMS_NORM; + req->op = HTP_OP_RMS_NORM; supported = true; break; case GGML_OP_UNARY: - if (ggml_get_unary_op(dst) == GGML_UNARY_OP_SILU) { - req.op = HTP_OP_UNARY_SILU; + if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { + req->op = HTP_OP_UNARY_SILU; supported = true; - } - else if (ggml_get_unary_op(dst) == GGML_UNARY_OP_GELU){ - req.op = HTP_OP_UNARY_GELU; + } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) { + req->op = HTP_OP_UNARY_GELU; supported = true; } break; case GGML_OP_GLU: - if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU) 
{ - req.op = HTP_OP_GLU_SWIGLU; + if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) { + req->op = HTP_OP_GLU_SWIGLU; supported = true; - } else if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) { - req.op = HTP_OP_GLU_SWIGLU_OAI; + } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) { + req->op = HTP_OP_GLU_SWIGLU_OAI; supported = true; } break; case GGML_OP_SOFT_MAX: - req.op = HTP_OP_SOFTMAX; + req->op = HTP_OP_SOFTMAX; supported = true; break; @@ -2694,194 +2307,28 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { } if (!supported) { - GGML_ABORT("ggml-hex: unary : unsupported op:%d\n", op->op); + GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op); } - init_htp_tensor(&req.dst, dst); - init_htp_tensor(&req.src0, src0); - if (src1) { - init_htp_tensor(&req.src1, src1); - } + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - // Use opmask to override flags - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; - } - - dspqueue_buffer bufs[3]; - - // First buffer = Only Operand of Unary op - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - size_t n_bufs = dspqueue_buffers_init(bufs, src0, true, true); - - // Second buffer(nullable) = Second Operand of Binary op - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src1, true, true); - - // Second or third buffer = Output Activations. We'll handle DSP - // Second buffer = Output Activations. We'll handle DSP - // cache maintenance in the response message but need to flush - // CPU caches to ensure any previously written dirty lines are - // written out before writes from the DSP start. 
- n_bufs += dspqueue_buffers_init(&bufs[n_bufs], dst, true, false); - - // Primary DSP session from the src0 tensor - auto * sess = get_session_from_tensor(src0); - - if (opt_verbose) { - hex_print_op_info(op, sess, req.flags); - if (opt_verbose > 1) { - hex_dump_dspbuf(src0, &bufs[0]); - if (src1) { - hex_dump_dspbuf(src1, &bufs[1]); - hex_dump_dspbuf(dst, &bufs[2]); - } else { - hex_dump_dspbuf(dst, &bufs[1]); - } - } - } - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - sess->enqueue(req, bufs, n_bufs, opt_opsync); - } - - t2 = ggml_time_us(); - - if (src1) { - HEX_PROFILE( - "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u " - "(%f) call-usec %llu\n", - sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], - (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], - (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, - (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); - } else { - HEX_PROFILE( - "ggml-hex: %s %s %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) call-usec " - "%llu\n", - sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], - (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, - (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); - } + return n_bufs; } -static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { - const struct ggml_tensor * src0 = op->src[0]; - const struct ggml_tensor * src1 = op->src[1]; - const struct ggml_tensor * src2 = op->src[2]; - const struct ggml_tensor * dst = op; +static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + req->op = HTP_OP_ROPE; - uint64_t t1 = 0; - uint64_t t2 = 0; + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - t1 = ggml_time_us(); - - // Construct HTP message - htp_general_req req; - - memset(&req, 0, sizeof(htp_general_req)); - memcpy(&req.op_params, &op->op_params, sizeof(op->op_params)); - req.flags = flags; - req.op = HTP_OP_ROPE; - - init_htp_tensor(&req.dst, dst); - init_htp_tensor(&req.src0, src0); - init_htp_tensor(&req.src1, src1); - if (src2) { - init_htp_tensor(&req.src2, src2); - } - - // Use opmask to override flags - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; - } - - dspqueue_buffer bufs[4]; - - // First buffer - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. 
On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - size_t n_bufs = dspqueue_buffers_init(bufs, src0, true, true); - - // Second buffer - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src1, true, true); - - // Third buffer(nullable) - // This is a buffer that the CPU writes and the DSP reads, so we'll - // need to flush CPU caches and invalidate DSP ones. On platforms - // with I/O coherency support the framework will automatically skip - // cache operations where possible. - n_bufs += dspqueue_buffers_init(&bufs[n_bufs], src2, true, true); - - // Final buffer = Output Activations. We'll handle DSP - // Second buffer = Output Activations. We'll handle DSP - // cache maintenance in the response message but need to flush - // CPU caches to ensure any previously written dirty lines are - // written out before writes from the DSP start. - n_bufs += dspqueue_buffers_init(&bufs[n_bufs], dst, true, false); - - // Primary DSP session from the src0 tensor - auto * sess = get_session_from_tensor(src0); - - if (opt_verbose) { - hex_print_op_info(op, sess, req.flags); - if (opt_verbose > 1) { - hex_dump_dspbuf(src0, &bufs[0]); - if (src1) { - hex_dump_dspbuf(src1, &bufs[1]); - hex_dump_dspbuf(dst, &bufs[2]); - } else { - hex_dump_dspbuf(dst, &bufs[1]); - } - } - } - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - sess->enqueue(req, bufs, n_bufs, opt_opsync); - } - - t2 = ggml_time_us(); - - if (src2) { - HEX_PROFILE( - "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles " - "%u op-pkts %u (%f) call-usec %llu\n", - sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], - (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], - (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], src2->name, (uint32_t) src2->ne[0], (uint32_t) src2->ne[1], - (uint32_t) src2->ne[2], (uint32_t) src2->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, - (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); - } else { - HEX_PROFILE( - "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u " - "(%f) call-usec %llu\n", - sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], - (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], - (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, - (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); - } + return n_bufs; } static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { @@ -2896,7 +2343,7 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { } static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1]); + return (op0 && op0->src[1] == op1->src[1] && 
ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type)); } static inline bool is_compute_op(ggml_tensor *node) @@ -2946,43 +2393,50 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg switch (node->op) { case GGML_OP_MUL_MAT: - ggml_hexagon_mul_mat(node, flags); + if (ggml_is_quantized(node->src[0]->type)) { + ggml_hexagon_dispatch_op>(sess, node, flags); + } else { + ggml_hexagon_dispatch_op>(sess, node, flags); + } prev_quant_op = node; break; case GGML_OP_MUL_MAT_ID: - ggml_hexagon_mul_mat_id(node, flags); + if (ggml_is_quantized(node->src[0]->type)) { + ggml_hexagon_dispatch_op>(sess, node, flags); + } else { + ggml_hexagon_dispatch_op>(sess, node, flags); + } prev_quant_op = node; break; case GGML_OP_MUL: case GGML_OP_ADD: case GGML_OP_SUB: - ggml_hexagon_binary(node, flags); + ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_ADD_ID: - ggml_hexagon_add_id(node, flags); + ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_RMS_NORM: - ggml_hexagon_unary(node, flags); + ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_UNARY: - if (ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) { - ggml_hexagon_unary(node, flags); - } else if (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU) { - ggml_hexagon_unary(node, flags); + if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || + (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { + ggml_hexagon_dispatch_op(sess, node, flags); } break; case GGML_OP_GLU: if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) { - ggml_hexagon_unary(node, flags); + (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) { + ggml_hexagon_dispatch_op(sess, node, flags); } break; case GGML_OP_SOFT_MAX: - ggml_hexagon_unary(node, flags); + ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_ROPE: - ggml_hexagon_rope(node, flags); + ggml_hexagon_dispatch_op(sess, node, flags); break; default: @@ -3111,8 +2565,8 @@ static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgr // and perform the reorder over the fused nodes. 
after the reorder is done, we unfuse for (int i = 0; i < n; i++) { node_info node = { - /*.node =*/ gf->nodes[i], - /*.fused =*/ {}, + /*.node =*/gf->nodes[i], + /*.fused =*/{}, }; // fuse only ops that start with these operations @@ -3263,9 +2717,38 @@ static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_repack_buffer_ return &sess->repack_buffer_type; } +static bool ggml_hexagon_supported_buffer(ggml_hexagon_session *sess, const struct ggml_tensor * t) { + if (t && t->buffer) { + if (ggml_backend_buffer_is_hexagon(t->buffer) == false) return false; // not our buffer + if (ggml_backend_hexagon_buffer_get_sess(t->buffer) != sess) return false; // wrong session + } + return true; +} + +static bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const struct ggml_tensor * t) { + // all srcs & dsts must be mapped to the same session + if (!ggml_hexagon_supported_buffer(sess, t)) { + return false; + } + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (!ggml_hexagon_supported_buffer(sess, t->src[i])) { + return false; + } + } + + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); + // all srcs & dsts must be mapped to the same session + if (!ggml_hexagon_supported_buffers(sess, op)) { + ggml_hexagon_dump_op_supp(sess->name, op, false); + return false; + } + bool supp = false; switch (op->op) { case GGML_OP_NONE: @@ -3303,20 +2786,21 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_UNARY: - if (ggml_get_unary_op(op) == GGML_UNARY_OP_SILU) { - supp = ggml_hexagon_supported_activations(sess, op); + { + const auto unary_op = ggml_get_unary_op(op); + if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) { + supp = ggml_hexagon_supported_activations(sess, op); + } + break; } - else if (ggml_get_unary_op(op) == GGML_UNARY_OP_GELU){ - supp = ggml_hexagon_supported_activations(sess, op); - } - break; - case GGML_OP_GLU: - if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) ) { - supp = ggml_hexagon_supported_activations(sess, op); + { + const auto glu_op = ggml_get_glu_op(op); + if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) { + supp = ggml_hexagon_supported_activations(sess, op); + } + break; } - break; - case GGML_OP_ROPE: supp = ggml_hexagon_supported_rope(sess, op); break; @@ -3325,26 +2809,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; } - if (opt_verbose) { - char dims[64 * GGML_MAX_SRC]; - char strides[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - hex_format_op_dims(dims, op); - hex_format_op_strides(strides, op); - hex_format_op_types(types, op); - hex_format_op_buffs(buffs, op); - hex_format_op_names(names, op); - - HEX_VERBOSE("ggml-hex: %s device-supports-op %s : %s : %s : %s : %s : %s : (%d)\n", sess->name.c_str(), - ggml_op_name(op->op), names, dims, types, strides, buffs, (int) supp); - } - + ggml_hexagon_dump_op_supp(sess->name, op, supp); return supp; - - GGML_UNUSED(dev); } static bool ggml_backend_hexagon_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { @@ -3413,7 +2879,7 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { } } - if(opt_arch < 75) { + if (opt_arch < 75) { opt_ndev = 1; GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for 
SoCs archs lower than v75.\n"); } @@ -3422,11 +2888,11 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { // Create devices / sessions for (size_t i = 0; i < opt_ndev; i++) { - devices[i].iface = ggml_backend_hexagon_device_i; - devices[i].reg = reg; + devices[i].iface = ggml_backend_hexagon_device_i; + devices[i].reg = reg; try { devices[i].context = new ggml_hexagon_session(i, &devices[i]); - } catch (std::exception const &exc) { + } catch (const std::exception & exc) { GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i); devices[i].context = nullptr; } diff --git a/ggml/src/ggml-hexagon/htp-utils.h b/ggml/src/ggml-hexagon/htp-utils.h index 1a48f5dcbd..7bbae3a0b7 100644 --- a/ggml/src/ggml-hexagon/htp-utils.h +++ b/ggml/src/ggml-hexagon/htp-utils.h @@ -8,6 +8,7 @@ extern "C" { #include #include #include +#include #include /* Offset to differentiate HLOS and Hexagon error codes. diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 586b5c1f92..7e488456ee 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -263,7 +263,8 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, struct htp_spad * dst_spad, uint32_t nth, uint32_t ith, - uint32_t src0_nrows_per_thread) { + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { htp_act_preamble2; uint64_t t1, t2; @@ -271,6 +272,8 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; + const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); const uint32_t src0_nrows = ne01 * ne02 * ne03; @@ -282,60 +285,81 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, return; } - int is_aligned = 1; - int opt_path = 0; - if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { - is_aligned = 0; - FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n"); - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + const uint8_t * data_src0 = (const uint8_t *) src0->data; + uint8_t * data_dst = (uint8_t *) dst->data; + + uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + // In gelu = x*sigmoid(x*1.702) + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + + if (BLOCK == 0) { + FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); - uint8_t * restrict dst_spad_data = 
dst_spad->data + (ith * dst_row_size); + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + } - const int BLOCK = 8; for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { - const uint32_t block_end = MIN(ir + BLOCK, src0_end_row); + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); - // Prefetch next block - if (block_end < src0_end_row) { - const float * restrict prefetch_ptr = (float *) (data_src0 + (block_end * src0_row_size)); - htp_l2fetch(prefetch_ptr, 1, block_end * src0_row_size, src0_row_size); - } + float* dst_spad = (float *) dma_queue_pop(dma_queue).src; + float* src0_spad = (float *) dma_queue_pop(dma_queue).dst; - // Process rows in current block - for (uint32_t ib = ir; ib < block_end; ib++) { - const float * restrict src0 = (float *) (data_src0 + (ib * src0_row_size)); - float * restrict dst = (float *) (data_dst + (ib * dst_row_size)); + for (uint32_t ib = 0; ib < block_size; ib++) { + const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float)); + float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // gelu = x * sigmoid(1.702 * x) // current implementation - if (1 == opt_path) { - hvx_mul_scalar_f32((const uint8_t *) src0, (float) 1.702, (uint8_t *) src0_spad_data, ne0); - hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0); - hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); - } else { - hvx_mul_scalar_f32( (const uint8_t *) src0, (float)1.702, (uint8_t *) src0_spad_data, ne0); - hvx_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0); - hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); - } + hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, (float) 1.702, (uint8_t *) dst_spad_ptr, ne0); + hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); + } + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), + dst_row_size, dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); } } + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "gelu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02, + FARF(HIGH, "gelu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, 
&octx->dst_spad, n, i, - octx->src0_nrows_per_thread); + octx->src0_nrows_per_thread, octx->ctx->dma[i]); } @@ -468,21 +492,45 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) { const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; - const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src1->ne[0] ? src1->nb[1] : src0->nb[1]; - const size_t dst_row_size = dst->nb[1]; + size_t src0_row_size = src0->nb[1]; + size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used + size_t dst_row_size = dst->nb[1]; + const bool src1_valid = src1->ne[0]; + if (!src1_valid) { + src1_row_size = src0_row_size; + } + + const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * octx->n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * octx->n_threads; - octx->src1_spad.size = htp_round_up(src1_row_size, 128) * octx->n_threads; - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned; + size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row); + + // Make sure the reserved vtcm size is sufficient + if(vtcm_row_per_thread ==0){ + FARF(ERROR, "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", op_type, octx->ctx->vtcm_size, + spad_size_per_row * n_threads); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread; + octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread; + octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread; + + octx->dst_spad.size = n_threads* octx->dst_spad.size_per_thread; + octx->src0_spad.size = n_threads* octx->src0_spad.size_per_thread; + octx->src1_spad.size = n_threads* octx->src1_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; if (src1->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); @@ -492,20 +540,8 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) { octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); } - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "act-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); - return HTP_STATUS_VTCM_TOO_SMALL; - } - - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + 
octx->src1_spad.size; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs); } diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.c b/ggml/src/ggml-hexagon/htp/htp-dma.c index 10c54b45ee..880c4542a0 100644 --- a/ggml/src/ggml-hexagon/htp/htp-dma.c +++ b/ggml/src/ggml-hexagon/htp/htp-dma.c @@ -34,12 +34,12 @@ dma_queue * dma_queue_create(size_t capacity) { q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t)); memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t)); - q->dst = (void **) memalign(4, capacity * sizeof(void *)); - memset(q->dst, 0, capacity * sizeof(void *)); + q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr)); + memset(q->dptr, 0, capacity * sizeof(dma_ptr)); q->tail = &q->desc[capacity - 1]; - if (!q->desc && !q->dst) { + if (!q->desc && !q->dptr) { FARF(ERROR, "%s: failed to allocate DMA queue items\n", __FUNCTION__); return NULL; } @@ -54,16 +54,10 @@ void dma_queue_delete(dma_queue * q) { return; } free(q->desc); - free(q->dst); + free(q->dptr); free(q); } void dma_queue_flush(dma_queue * q) { - while (1) { - uint32_t s = dmwait() & 0x3; - if (s == HEXAGON_UDMA_DM0_STATUS_IDLE) { - break; - } - } - q->tail = NULL; + while (dma_queue_pop(q).dst != NULL) ; } diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.h b/ggml/src/ggml-hexagon/htp/htp-dma.h index 7d3fc4078c..32fd06e7d4 100644 --- a/ggml/src/ggml-hexagon/htp/htp-dma.h +++ b/ggml/src/ggml-hexagon/htp/htp-dma.h @@ -11,10 +11,15 @@ extern "C" { #endif +typedef struct { + void *dst; + const void *src; +} dma_ptr; + typedef struct { hexagon_udma_descriptor_type1_t * desc; // descriptor pointers hexagon_udma_descriptor_type1_t * tail; // tail pointer - void ** dst; // dst pointers + dma_ptr * dptr; // dst/src pointers uint32_t push_idx; uint32_t pop_idx; uint32_t capacity; @@ -49,13 +54,20 @@ static inline unsigned int dmwait(void) { return ret; } -static inline bool dma_queue_push(dma_queue * q, - void * dst, - const void * src, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { +static inline dma_ptr dma_make_ptr(void *dst, const void *src) +{ + dma_ptr p = { dst, src }; + return p; +} + +static inline bool dma_queue_push(dma_queue * q, + dma_ptr dptr, + size_t dst_row_size, + size_t src_row_size, + size_t width, // width in bytes. 
number of bytes to transfer per row + size_t nrows) { if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { + FARF(ERROR, "dma-push: queue full\n"); return false; } @@ -75,18 +87,18 @@ static inline bool dma_queue_push(dma_queue * q, #endif desc->order = 0; desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE; - desc->src = (void *) src; - desc->dst = (void *) dst; + desc->src = (void *) dptr.src; + desc->dst = (void *) dptr.dst; desc->allocation = 0; desc->padding = 0; - desc->roiwidth = src_row_size; + desc->roiwidth = width; desc->roiheight = nrows; desc->srcstride = src_row_size; desc->dststride = dst_row_size; desc->srcwidthoffset = 0; desc->dstwidthoffset = 0; - q->dst[q->push_idx] = dst; + q->dptr[q->push_idx] = dptr; dmlink(q->tail, desc); q->tail = desc; @@ -96,9 +108,28 @@ static inline bool dma_queue_push(dma_queue * q, return true; } -static inline uint8_t * dma_queue_pop(dma_queue * q) { +static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, + dma_ptr dptr, + size_t dst_row_size, + size_t src_row_size, + size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); +} + + +static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, + dma_ptr dptr, + size_t dst_row_size, + size_t src_row_size, + size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); +} + +static inline dma_ptr dma_queue_pop(dma_queue * q) { + dma_ptr dptr = { NULL }; + if (q->push_idx == q->pop_idx) { - return NULL; + return dptr; } hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx]; @@ -112,11 +143,11 @@ static inline uint8_t * dma_queue_pop(dma_queue * q) { // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); } - uint8_t * dst = (uint8_t *) q->dst[q->pop_idx]; + dptr = q->dptr[q->pop_idx]; // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst); q->pop_idx = (q->pop_idx + 1) & q->idx_mask; - return dst; + return dptr; } #ifdef __cplusplus diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 566048297d..d2d5d23636 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -980,8 +980,6 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * int step_of_1 = num_elems >> 5; int remaining = num_elems - step_of_1 * VLEN_FP32; - assert(remaining == 0); - const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; @@ -996,8 +994,16 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * for (int i = 0; i < step_of_1; i++) { v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp); } -} + if (remaining > 0) { + const float * srcf = ((const float *) src) + step_of_1* VLEN_FP32; + float * dstf = (float *) dst + step_of_1*VLEN_FP32; + + HVX_Vector in = *(HVX_UVector *) srcf; + HVX_Vector out = hvx_vec_fast_sigmoid_fp32_guard(in, one, max_exp, min_exp); + hvx_vec_store_u((void *) dstf, remaining * SIZEOF_FP32, out); + } +} static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems){ int step_of_1 = num_elems >> 5; // divby 32, because 32 float = 128 bytes per HVX vector diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 656c369d0a..fb5508a560 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -299,7 +299,8 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que 
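The queue above is a fixed-capacity ring of UDMA descriptors: `dma_queue_push` links a new descriptor behind the tail, `dma_queue_pop` spins until the oldest descriptor completes and hands back its `dma_ptr`, and `dma_queue_flush` now drains by popping until empty. In case it helps to see the discipline in isolation, here is a minimal host-side model of just the ring bookkeeping (a sketch: `sim_dma_queue` is illustrative only and fakes the transfer with `memcpy` instead of programming descriptors and polling `dmwait()`):

```cpp
// Host-side model of the dma_queue ring bookkeeping (illustrative sketch only).
#include <cstddef>
#include <cstring>
#include <vector>

struct dma_ptr { void * dst; const void * src; };

struct sim_dma_queue {
    std::vector<dma_ptr> slots;          // capacity must be a power of two
    size_t push_idx = 0, pop_idx = 0;
    size_t idx_mask;

    explicit sim_dma_queue(size_t capacity) : slots(capacity), idx_mask(capacity - 1) {}

    bool push(dma_ptr p, size_t nbytes) {
        if (((push_idx + 1) & idx_mask) == pop_idx) return false; // queue full
        std::memcpy(p.dst, p.src, nbytes); // real queue: link a UDMA descriptor instead
        slots[push_idx] = p;
        push_idx = (push_idx + 1) & idx_mask;
        return true;
    }

    dma_ptr pop() { // real queue: spins on the descriptor's done bit first
        if (push_idx == pop_idx) return {nullptr, nullptr};
        dma_ptr p = slots[pop_idx];
        pop_idx = (pop_idx + 1) & idx_mask;
        return p;
    }
};
```

Seen this way, the dummy vtcm-to-ddr push in the gelu kernel is pure sequencing: it keeps the pop order alternating dst,src,dst,... so the `dma_queue_pop(...).src` / `dma_queue_pop(...).dst` pair at the top of each block lines up with what was pushed.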
ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { - ctx->dma[i] = dma_queue_create(HTP_SPAD_SRC0_NROWS * 2); + // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 + ctx->dma[i] = dma_queue_create(64); } // init worker pool diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 0c9188244d..f14523d485 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -1127,13 +1127,13 @@ static void matmul(struct htp_matmul_type * mt, if (is0 >= HTP_SPAD_SRC0_NROWS) { break; } - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); } // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { @@ -1146,7 +1146,7 @@ static void matmul(struct htp_matmul_type * mt, const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); } } @@ -1155,9 +1155,9 @@ static void matmul(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; const int is0 = (ir0 - src0_start_row); - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { @@ -1229,20 +1229,20 @@ static void matvec(struct htp_matmul_type * mt, if (is0 >= HTP_SPAD_SRC0_NROWS) { break; } - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); } // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); } } @@ -1251,9 +1251,9 @@ static void matvec(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; const uint32_t is0 = (ir0 - 
src0_start_row); - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } @@ -1343,13 +1343,13 @@ static void matmul_id(struct htp_matmul_type * mt, if (is0 >= HTP_SPAD_SRC0_NROWS) { break; } - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); } // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; for (uint32_t cid = 0; cid < cne1; ++cid) { struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); @@ -1368,7 +1368,7 @@ static void matmul_id(struct htp_matmul_type * mt, const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); } } @@ -1377,9 +1377,9 @@ static void matmul_id(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; for (uint32_t cid = 0; cid < cne1; ++cid) { struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); @@ -1467,20 +1467,20 @@ static void matvec_id(struct htp_matmul_type * mt, if (is0 >= HTP_SPAD_SRC0_NROWS) { break; } - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); } // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); } } @@ -1489,9 +1489,9 @@ static void matvec_id(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; const 
uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); } } diff --git a/ggml/src/ggml-hexagon/op-desc.h b/ggml/src/ggml-hexagon/op-desc.h new file mode 100644 index 0000000000..a1e8ddd8b9 --- /dev/null +++ b/ggml/src/ggml-hexagon/op-desc.h @@ -0,0 +1,153 @@ +#ifndef OP_DESC_H +#define OP_DESC_H + +#define GGML_COMMON_IMPL_CPP +#include "ggml-backend-impl.h" +#include "ggml-common.h" + +#include <cstdio> +#include <cstddef> + +struct op_desc { + char strides[64 * GGML_MAX_SRC]; + char dims[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + } else { + return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + } + } + + void format_op_dims(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += format_tensor_dims(p, t->src[0]); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += format_tensor_dims(p, t->src[i]); + } + + p += sprintf(p, " -> "); + } + + // format self dims separately for better visual alignment + char self[64]; + format_tensor_dims(self, t); + + p += sprintf(p, "%s", self); + } + + int format_tensor_strides(char * str, const struct ggml_tensor * t) { + const char * c = ggml_is_contiguous(t) ?
"" : "!"; + + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + } else { + return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); + } + } + + void format_op_strides(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += format_tensor_strides(p, t->src[0]); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += format_tensor_strides(p, t->src[i]); + } + + p += sprintf(p, " -> "); + } + + // format self dims separately for better visual alignment + char self[64]; + format_tensor_strides(self, t); + + p += sprintf(p, "%s", self); + } + + void format_op_types(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += sprintf(p, "%s", ggml_type_name(t->src[0]->type)); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", ggml_type_name(t->src[i]->type)); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", ggml_type_name(t->type)); + } + + const char * tensor_buff_name(const struct ggml_tensor * t) { + if (t->buffer) { + return ggml_backend_buffer_name(t->buffer); + } + return "NONE"; + } + + void format_op_buffs(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += sprintf(p, "%s", tensor_buff_name(t->src[0])); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", tensor_buff_name(t->src[i])); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", tensor_buff_name(t)); + } + + void format_op_names(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += sprintf(p, "%s", t->src[0]->name); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", t->src[i]->name); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", t->name); + } + + void format(const ggml_tensor * op) { + format_op_dims(dims, op); + format_op_strides(strides, op); + format_op_types(types, op); + format_op_buffs(buffs, op); + format_op_names(names, op); + } + + op_desc() {} + op_desc(const ggml_tensor * op) { format(op); } +}; + +#endif // OP_DESC_H diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0d37587f60..639715537b 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -494,6 +494,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; + cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -634,6 +635,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_transpose_32; cl_kernel kernel_transpose_32_16; cl_kernel kernel_transpose_16; + cl_kernel kernel_transpose_16_buf; cl_kernel kernel_transpose_16_4x1; cl_mem A_s_d_max; // max scale buffer size for transpose @@ -806,6 +808,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve 
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); @@ -2004,7 +2007,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); - CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_buf", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err)); GGML_LOG_CONT("."); } @@ -3933,6 +3937,91 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q4_0) { ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_int err; + cl_kernel kernel; + + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + GGML_ASSERT(K % 32 == 0); + GGML_ASSERT(M % 4 == 0); + + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_mem buf_trans_q; + cl_mem buf_trans_d; + + CL_CHECK((buf_trans_q = clCreateBuffer(context, CL_MEM_READ_WRITE, + size_q, NULL, &err), err)); + CL_CHECK((buf_trans_d = clCreateBuffer(context, CL_MEM_READ_WRITE, + size_d, NULL, &err), err)); + + kernel = backend_ctx->kernel_transpose_16_buf; + + // transpose q back + cl_int stride_k_q = K/4; + size_t local_size_q[3] = {64, 1, 1}; + size_t global_size_q[3] = {(size_t)M, (size_t)stride_k_q, 1}; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_q)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_size_q, local_size_q, 0, NULL, NULL)); + + // transpose scales back + cl_int stride_k_d = K/32; + size_t local_size_d[3] = {64, 1, 1}; + size_t global_size_d[3] = {(size_t)M, (size_t)stride_k_d, 1}; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), 
&extra->d)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_d)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_size_d, local_size_d, 0, NULL, NULL)); + + // unpack + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + + // read back to host + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + + CL_CHECK(clReleaseMemObject(data_device)); + CL_CHECK(clReleaseMemObject(buf_trans_q)); + CL_CHECK(clReleaseMemObject(buf_trans_d)); + + return; + } +#endif + cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err); diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index b26f9c5fb2..513a4d3e28 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -117,6 +117,27 @@ kernel void kernel_convert_block_q4_0_noshuffle( } } +kernel void kernel_restore_block_q4_0_noshuffle( + global uchar * src_q, + global half * src_d, + global struct block_q4_0 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_0/4; ++i) { + uchar x0 = q[i + 0 ] ; + uchar x1 = q[i + QK4_0/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl index 536dd560a9..1279b6531b 100644 --- a/ggml/src/ggml-opencl/kernels/transpose.cl +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -44,6 +44,19 @@ kernel void kernel_transpose_16_4x1( write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3)); } +// Transpose treating each element as 16-bit using buffer +kernel void kernel_transpose_16_buf( + global const ushort * input, + global ushort * output, + const int ldi, + const int ldo +) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + output[x*ldo + y] = input[y*ldi + x]; +} + // 32-bit transpose, loading/storing a 4x4 tile of elements kernel void kernel_transpose_32( __read_only image1d_buffer_t input, diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp 
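For reference, the nibble repack that `kernel_restore_block_q4_0_noshuffle` performs can be restated in scalar C++ (a sketch; `block_q4_0_ref` is a simplified stand-in for ggml's `block_q4_0`, with QK4_0 = 32 as in ggml-common, and the fp16 scale kept as a raw 16-bit value):

```cpp
// Scalar reference for the de-shuffled q4_0 repack: element i's low nibble
// lives in q[i] and pairs with q[i + QK4_0/4] to rebuild the packed bytes.
#include <cstdint>

constexpr int QK4_0 = 32;

struct block_q4_0_ref {
    uint16_t d;              // fp16 scale, stored raw here
    uint8_t  qs[QK4_0 / 2];  // 4-bit quants, two per byte
};

void restore_block_q4_0_noshuffle(const uint8_t * q, uint16_t d, block_q4_0_ref * b) {
    b->d = d;
    for (int i = 0; i < QK4_0 / 4; ++i) {
        const uint8_t x0 = q[i];
        const uint8_t x1 = q[i + QK4_0 / 4];
        b->qs[2*i + 0] = (uint8_t) ((x0 & 0x0F) | ((x1 & 0x0F) << 4));
        b->qs[2*i + 1] = (uint8_t) (((x0 & 0xF0) >> 4) | (x1 & 0xF0));
    }
}
```

Passing `mask_0F`/`mask_F0` as kernel arguments rather than literals mirrors the existing convert kernel; presumably this sidesteps a driver compiler quirk rather than serving any algorithmic purpose.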
index 13cf1f5f9d..e7890a5ee9 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -571,6 +571,10 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { return ctx->base_ptr; } +static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer; +} + static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { rpc_tensor result; if (!tensor) { @@ -580,7 +584,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { result.id = reinterpret_cast<uint64_t>(tensor); result.type = tensor->type; - if (tensor->buffer) { + if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) { ggml_backend_buffer_t buffer = tensor->buffer; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; result.buffer = ctx != nullptr ? ctx->remote_ptr : 0; @@ -664,10 +668,6 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con RPC_STATUS_ASSERT(status); } -static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) { - return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer; -} - static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { if (ggml_backend_buffer_is_rpc(src->buffer)) { // check if src and dst are on the same server diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c2adca9cba..4f50d378cd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -379,18 +379,18 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) - : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc) + : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {} uint32_t HSK, HSV; - bool small_rows; + bool small_rows, small_cache; FaCodePath path; bool aligned; bool f32acc; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < - std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); + return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) < + std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc); } }; @@ -731,7 +731,7 @@ struct vk_device_struct { vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; - vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; + vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; @@ -2582,10 +2582,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { +static uint32_t
get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { if (hsv >= 192) { return 2; - } else if ((hsv | hsk) & 8) { + } else if ((hsv | hsk) & 8 || small_cache) { return 4; } else { return 8; @@ -2607,9 +2607,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) { } } -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) { GGML_UNUSED(clamp); - GGML_UNUSED(hsv); if (path == FA_SCALAR) { if (small_rows) { @@ -2618,9 +2617,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; } else { - return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32}; } } } @@ -2649,8 +2648,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { - return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1]; +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) { + return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1]; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -2992,11 +2991,11 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. @@ -3006,7 +3005,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) ? scalar_flash_attention_workgroup_size : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. 
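Condensed, the FA_SCALAR tile selection after this change reads as follows (a sketch; only the large-rows path is shown, and the constants mirror the diff). The naming suggests the motivation: a short KV cache (`small_cache`) yields few Bc-wide column blocks, so shrinking Br from 8 to 4 presumably spreads the work across more workgroups.

```cpp
// Condensed restatement of the scalar-path tile selection (sketch).
#include <array>
#include <cstdint>

uint32_t fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
    if (hsv >= 192)                       return 2; // very large value heads: fewest rows
    if (((hsv | hsk) & 8) || small_cache) return 4; // odd head size or short KV cache
    return 8;                                       // default: widest row tile
}

std::array<uint32_t, 2> fa_scalar_rows_cols(uint32_t hsk, uint32_t hsv, bool small_cache) {
    // HSK/HSV not being a multiple of 16 shrinks D_split, which grows cols_per_iter,
    // and Bc must stay >= cols_per_thread -- hence 64 instead of 32.
    const uint32_t bc = ((hsv | hsk) & 8) ? 64 : 32;
    return {fa_scalar_num_large_rows(hsk, hsv, small_cache), bc};
}
```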
@@ -3021,21 +3020,22 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t HSK = fa.first.HSK; \ uint32_t HSV = fa.first.HSV; \ bool small_rows = fa.first.small_rows; \ + bool small_cache = fa.first.small_cache; \ FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } \ } \ @@ -4077,6 +4077,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } else { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -4085,6 +4086,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { @@ -8006,11 +8008,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -8134,6 +8136,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; + const bool small_cache = nek1 < 1024; + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). 
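Since the flash-attention pipelines are cached in an ordered container keyed by `vk_fa_pipeline_state`, `small_cache` has to participate in `operator<` or two specializations differing only in that flag would collide on the same cache slot. A stripped-down sketch of the keying scheme (`FaCodePath` omitted for brevity; the `std::tie` idiom gives lexicographic comparison over all fields):

```cpp
// Minimal model of the pipeline-cache key (sketch; fields mirror the diff).
#include <cstdint>
#include <map>
#include <tuple>

struct fa_key {
    uint32_t hsk, hsv;
    bool small_rows, small_cache, aligned, f32acc;

    bool operator<(const fa_key & b) const {
        return std::tie(hsk, hsv, small_rows, small_cache, aligned, f32acc)
             < std::tie(b.hsk, b.hsv, b.small_rows, b.small_cache, b.aligned, b.f32acc);
    }
};

// std::map<fa_key, some_pipeline_type> cache; // distinct key -> distinct pipeline
```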
uint32_t max_gqa; @@ -8141,7 +8145,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); @@ -8175,7 +8179,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) { small_rows = true; } @@ -8191,7 +8195,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); + uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache); bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; @@ -8203,7 +8207,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc); vk_pipeline pipeline = nullptr; @@ -8680,6 +8684,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rope_multi_f32; } + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_multi_f32_f16; + } if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_rope_multi_f16; } @@ -13076,9 +13083,9 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return false; } - // Only norm/neox shaders have the fusion code + // Only norm/neox/mrope shaders have the fusion code const int mode = ((const int32_t *) rope->op_params)[2]; - if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) { return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index 9726b722d1..aacec98469 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -49,8 +49,8 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { uint idst = i1*ne0 + i0; const uint ix = rope_a_coord(i0, i01, i02, p); - // Fusion optimization: ROPE + VIEW + SET_ROWS.. - // The rope output is viewed as a 1D tensor and offset based on a row index in data_i. + // Fusion optimization: ROPE + VIEW + SET_ROWS. + // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. 
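The destination indexing used by the fused ROPE + VIEW + SET_ROWS path is the same across the norm/neox/multi shader variants; restated in C++ for clarity (a sketch; `row_id` stands in for `rope_data_i[i02].x` from the shader, and the halved `i0/2` matches the neox/multi pairing):

```cpp
// Fused-path destination index for rope_multi/rope_neox (sketch).
#include <cstdint>

uint32_t rope_fused_dst_index(uint32_t i0, uint32_t i01, uint32_t i02,
                              uint32_t ne0, uint32_t ne1,
                              uint32_t set_rows_stride, uint32_t row_id) {
    const uint32_t i1 = i02 * ne1 + i01;         // shader derives i01 = i1 % ne1, i02 = i1 / ne1
    uint32_t idst = i1 * ne0 + i0 / 2;           // plain (unfused) output index
    if (set_rows_stride != 0) {
        // fused path: view the output as 1D and scatter by the SET_ROWS row id
        idst = i01 * ne0 + i0 / 2 + row_id * set_rows_stride;
    }
    return idst;
}
```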
if (p.set_rows_stride != 0) { idst = i01*ne0 + i0; idst += rope_data_i[i02].x * p.set_rows_stride; @@ -91,7 +91,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { uint idst = i1*ne0 + i0/2; const uint ix = rope_a_coord(i0/2, i01, i02, p); - // Fusion optimization: ROPE + VIEW + SET_ROWS.. + // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { idst = i01*ne0 + i0/2; @@ -132,9 +132,16 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { const uint i01 = i1 % ne1; const uint i02 = i1 / ne1; - const uint idst = i1*ne0 + i0/2; + uint idst = i1*ne0 + i0/2; const uint ix = rope_a_coord(i0/2, i01, i02, p); + // Fusion optimization: ROPE + VIEW + SET_ROWS. + // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. + if (p.set_rows_stride != 0) { + idst = i01*ne0 + i0/2; + idst += rope_data_i[i02].x * p.set_rows_stride; + } + if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 92ad3bcab1..e237a8e102 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -927,6 +927,8 @@ void process_shaders() { string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index cab8f2901a..41d3bd4faf 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -181,6 +181,7 @@ class Keys: DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" FREQ_BASE = "{arch}.rope.freq_base" + FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" SCALING_TYPE = "{arch}.rope.scaling.type" SCALING_FACTOR = "{arch}.rope.scaling.factor" SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" @@ -354,6 +355,7 @@ class MODEL_ARCH(IntEnum): STARCODER = auto() REFACT = auto() BERT = auto() + MODERN_BERT = auto() NOMIC_BERT = auto() NOMIC_BERT_MOE = auto() NEO_BERT = auto() @@ -747,6 +749,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.STARCODER: "starcoder", MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", + MODEL_ARCH.MODERN_BERT: "modern-bert", MODEL_ARCH.NOMIC_BERT: "nomic-bert", MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", MODEL_ARCH.NEO_BERT: "neo-bert", @@ -1367,6 +1370,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.CLS, MODEL_TENSOR.CLS_OUT, ], + MODEL_ARCH.MODERN_BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + 
MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, + ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 9e6ff3ac77..6a4a504f8d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -774,8 +774,12 @@ class GGUFWriter: def add_shared_kv_layers(self, value: int) -> None: self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value) - def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: - self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) + def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None: + key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch) + if isinstance(value, int): + self.add_uint32(key, value) + else: + self.add_array(key, value) def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None: self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f) @@ -886,6 +890,9 @@ class GGUFWriter: def add_value_residual_mix_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length) + def add_rope_freq_base_swa(self, value: float) -> None: + self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value) + def add_gate_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 301aafa910..276720fcde 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -17,6 +17,7 @@ class TensorNameMap: "embed_tokens", # embeddinggemma "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert + "embeddings.tok_embeddings", # modern-bert "language_model.embedding.word_embeddings", # persimmon "wte", # gpt2 "transformer.embd.wte", # phi2 @@ -46,6 +47,7 @@ class TensorNameMap: MODEL_TENSOR.TOKEN_EMBD_NORM: ( "word_embeddings_layernorm", # bloom "embeddings.LayerNorm", # bert + "embeddings.norm", # modern-bert "emb_ln", # nomic-bert "transformer.norm", # openelm "rwkv.blocks.0.pre_ln", # rwkv @@ -75,6 +77,7 @@ class TensorNameMap: "head.out", # wavtokenizer "lm_head", # llama4 "model.transformer.ff_out", # llada + "head.decoder", # modern-bert ), MODEL_TENSOR.DENSE_2_OUT: ( "dense_2_out", # embeddinggemma @@ -104,6 +107,7 @@ class TensorNameMap: "backbone.final_layer_norm", # wavtokenizer "model.norm", # llama4 "model.transformer.ln_f", # llada + "final_norm", # modern-bert "model.norm", # cogvlm ), @@ -151,6 +155,7 @@ class TensorNameMap: "model.layers.{bid}.input_layernorm", # llama4 "layers.{bid}.input_layernorm", # embeddinggemma "transformer_encoder.{bid}.attention_norm", # neobert + "layers.{bid}.attn_norm", # modern-bert "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada "layers.{bid}.input_layernorm", # qwen3-embedding @@ -187,6 +192,7 @@ class TensorNameMap: "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm "transformer_encoder.{bid}.qkv", # neobert + "layers.{bid}.attn.Wqkv", # modern-bert "model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm ), @@ -261,6 +267,7 @@ class 
TensorNameMap: "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert + "layers.{bid}.attn.Wo", # modern-bert "transformer.layer.{bid}.attention.out_lin", # distillbert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon @@ -344,6 +351,7 @@ class TensorNameMap: "layers.{bid}.post_attention_layernorm", # qwen3-embedding "model.layers.{bid}.feedforward_layernorm", # apertus "model.layers.{bid}.pre_mlp_layernorm", # kormo + "layers.{bid}.mlp_norm" # modern-bert ), # Pre feed-forward norm @@ -407,6 +415,7 @@ class TensorNameMap: "layers.{bid}.mlp.up_proj", # embeddinggemma "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert + "layers.{bid}.mlp.Wi", # modern-bert "transformer.layer.{bid}.ffn.lin1", # distillbert "transformer.h.{bid}.mlp.fc_in", # gpt-j "transformer.h.{bid}.mlp.linear_3", # refact @@ -521,6 +530,7 @@ class TensorNameMap: "layers.{bid}.mlp.down_proj", # embeddinggemma "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert + "layers.{bid}.mlp.Wo", # modern-bert "transformer.layer.{bid}.ffn.lin2", # distillbert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon @@ -1122,6 +1132,7 @@ class TensorNameMap: "classifier.dense", # roberta "pre_classifier", # distillbert "dense", # neobert + "head.dense", # modern-bert ), MODEL_TENSOR.CLS_OUT: ( diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index 4918ae971a..154351d8ed 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -110,7 +110,6 @@ class SafetensorRemote: """ BASE_DOMAIN = "https://huggingface.co" - ALIGNMENT = 8 # bytes @classmethod def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: @@ -204,9 +203,6 @@ class SafetensorRemote: # Calculate the data start offset data_start_offset = 8 + metadata_length - alignment = SafetensorRemote.ALIGNMENT - if data_start_offset % alignment != 0: - data_start_offset += alignment - (data_start_offset % alignment) # Check if we have enough data to read the metadata if len(raw_data) < 8 + metadata_length: @@ -298,7 +294,6 @@ class SafetensorsLocal: Custom parsing gives a bit more control over the memory usage. The official safetensors library doesn't expose file ranges. 
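The removed `ALIGNMENT` logic was skipping bytes that are not there: in the safetensors layout, tensor data begins immediately after the 8-byte little-endian header length and the JSON blob, with no external alignment padding. A sketch of the offset computation in C++ (assumes a little-endian host and `raw.size() >= 8`):

```cpp
// Where tensor data starts in a safetensors file (sketch).
#include <cstdint>
#include <cstring>
#include <vector>

size_t safetensors_data_start(const std::vector<uint8_t> & raw) {
    uint64_t metadata_length = 0;
    std::memcpy(&metadata_length, raw.data(), 8); // little-endian JSON header length
    return 8 + (size_t) metadata_length;          // data begins right here, unpadded
}
```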
""" - ALIGNMENT = 8 # bytes tensors: dict[str, LocalTensor] @@ -316,9 +311,6 @@ class SafetensorsLocal: raise ValueError(f"Failed to parse safetensors metadata as JSON: {e}") data_start_offset = f.tell() - alignment = self.ALIGNMENT - if data_start_offset % alignment != 0: - data_start_offset += alignment - (data_start_offset % alignment) tensors: dict[str, LocalTensor] = {} for name, meta in metadata.items(): diff --git a/scripts/snapdragon/adb/run-cli.sh b/scripts/snapdragon/adb/run-cli.sh index cc5e47c2d6..8a3053c859 100755 --- a/scripts/snapdragon/adb/run-cli.sh +++ b/scripts/snapdragon/adb/run-cli.sh @@ -18,17 +18,17 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf" device="HTP0" [ "$D" != "" ] && device="$D" -verbose= -[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" - experimental= [ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v" + sched= [ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" profile= -[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" opmask= [ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" @@ -45,9 +45,9 @@ adb $adbserial shell " \ cd $basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $experimental $sched $opmask $profile $nhvx $ndev \ - ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \ - --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \ - -ngl 99 --device $device $cli_opts $@ \ + $verbose $experimental $sched $opmask $profile $nhvx $ndev \ + ./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --ctx-size 8192 --batch-size 128 -fa on \ + -ngl 99 --device $device $cli_opts $@ \ " diff --git a/scripts/snapdragon/adb/run-completion.sh b/scripts/snapdragon/adb/run-completion.sh new file mode 100755 index 0000000000..bb7ba5e671 --- /dev/null +++ b/scripts/snapdragon/adb/run-completion.sh @@ -0,0 +1,53 @@ +#!/bin/sh +# + +# Basedir on device +basedir=/data/local/tmp/llama.cpp + +cli_opts= + +branch=. 
+[ "$B" != "" ] && branch=$B + +adbserial= +[ "$S" != "" ] && adbserial="-s $S" + +model="Llama-3.2-3B-Instruct-Q4_0.gguf" +[ "$M" != "" ] && model="$M" + +device="HTP0" +[ "$D" != "" ] && device="$D" + +experimental= +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" + +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v" + +sched= +[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" + +opmask= +[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" + +nhvx= +[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" + +ndev= +[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" + +set -x + +adb $adbserial shell " \ + cd $basedir; ulimit -c unlimited; \ + LD_LIBRARY_PATH=$basedir/$branch/lib \ + ADSP_LIBRARY_PATH=$basedir/$branch/lib \ + $verbose $experimental $sched $opmask $profile $nhvx $ndev \ + ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --ctx-size 8192 --batch-size 128 -fa on \ + -ngl 99 -no-cnv --device $device $cli_opts $@ \ +" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4192af7c0c..4ca8974916 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -90,6 +90,7 @@ add_library(llama models/mamba.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/modern-bert.cpp models/mpt.cpp models/nemotron-h.cpp models/nemotron.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d0eaf317f7..80f44ae1bf 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -20,6 +20,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_MODERN_BERT, "modern-bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, @@ -204,6 +205,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, @@ -214,6 +216,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, @@ -778,6 +781,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, }; + case LLM_ARCH_MODERN_BERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + }; case LLM_ARCH_JINA_BERT_V2: return { LLM_TENSOR_TOKEN_EMBD, diff --git a/src/llama-arch.h b/src/llama-arch.h index 6cbf9b1f89..a53bc39d18 
100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -24,6 +24,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_REFACT, LLM_ARCH_BERT, + LLM_ARCH_MODERN_BERT, LLM_ARCH_NOMIC_BERT, LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, @@ -208,6 +209,7 @@ enum llm_kv { LLM_KV_ATTENTION_GATE_LORA_RANK, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, @@ -218,6 +220,7 @@ enum llm_kv { LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_FREQ_BASE_SWA, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index b32674ab76..6507cdf25b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -491,23 +491,22 @@ llama_context::llama_context( } llama_context::~llama_context() { - // FIXME this currently results in a use-after-free bug if the model is freed before the context - // if (!model.hparams.no_alloc) { - // for (size_t i = 0; i < backend_ptrs.size(); ++i) { - // ggml_backend_t backend = backend_ptrs[i]; - // ggml_backend_buffer_type_t buft = backend_buft[i]; + if (!model.hparams.no_alloc) { + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; - // const size_t size_exp = backend_buf_exp_size[i]; - // const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); - // if (size_exp == size_act) { - // LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", - // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - // } else { - // LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", - // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - // } - // } - // } + const size_t size_exp = backend_buf_exp_size[i]; + const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (size_exp == size_act) { + LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); + } else { + LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); + } + } + } ggml_opt_free(opt_ctx); } diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 33a76dba40..5003b4fbf5 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -462,6 +462,29 @@ namespace GGUFMeta { return get_key_or_arr(llm_kv(kid), result, n, required); } + bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) { + const std::string key = llm_kv(kid); + + const int id = gguf_find_key(meta.get(), key.c_str()); + + if (id < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + // throw an error if the key's type is an array + if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str())); + } + return false; + } + + return get_key(key,
result, required); + } + // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required); diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index 0380c92fde..d13299ad3f 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -131,6 +131,8 @@ struct llama_model_loader { template <typename T> bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true); + bool get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required = true); + std::string get_arch_name() const; enum llm_arch get_arch() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1ce2364d63..c86937ad00 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -31,12 +31,14 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17M: return "17M"; case LLM_TYPE_22M: return "22M"; case LLM_TYPE_33M: return "33M"; + case LLM_TYPE_47M: return "47M"; case LLM_TYPE_60M: return "60M"; case LLM_TYPE_70M: return "70M"; case LLM_TYPE_80M: return "80M"; case LLM_TYPE_109M: return "109M"; case LLM_TYPE_137M: return "137M"; case LLM_TYPE_140M: return "140M"; + case LLM_TYPE_149M: return "149M"; case LLM_TYPE_160M: return "160M"; case LLM_TYPE_190M: return "190M"; case LLM_TYPE_220M: return "220M"; @@ -46,6 +48,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_335M: return "335M"; case LLM_TYPE_350M: return "350M"; case LLM_TYPE_360M: return "360M"; + case LLM_TYPE_395M: return "395M"; case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; case LLM_TYPE_475M: return "475M"; @@ -875,6 +878,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MODERN_BERT: + { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + uint32_t swa_period = 3; // default matches ModernBERT's global_attn_every_n_layers + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 12: + type = LLM_TYPE_47M; break; // granite-embedding-small + case 22: + type = LLM_TYPE_149M; break; // modern-bert-base + case 28: + type = LLM_TYPE_395M; break; // modern-bert-large + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -3155,6 +3186,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); } } break; + case LLM_ARCH_MODERN_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + if (i != 0) { + layer.attn_norm =
create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } else { + // layer 0 uses the identity (the embeddings are already normalized by tok_norm) + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } break; case LLM_ARCH_NEO_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5181,9 +5243,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp; - // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5235,6 +5294,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } else { if (n_expert != 0) { + const int64_t n_ff_exp = hparams.n_ff_exp ?
hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); @@ -7089,6 +7151,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: @@ -7248,6 +7311,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique<llm_build_bert>(*this, params); } break; + case LLM_ARCH_MODERN_BERT: + { + llm = std::make_unique<llm_build_modern_bert<true>>(*this, params); + } break; case LLM_ARCH_NEO_BERT: { llm = std::make_unique<llm_build_neo_bert>(*this, params); @@ -7821,6 +7888,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DBRX: case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V3: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: diff --git a/src/llama-model.h b/src/llama-model.h index c6eb953188..7f560d462f 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -24,12 +24,14 @@ enum llm_type { LLM_TYPE_17M, LLM_TYPE_22M, LLM_TYPE_33M, + LLM_TYPE_47M, LLM_TYPE_60M, LLM_TYPE_70M, LLM_TYPE_80M, LLM_TYPE_109M, LLM_TYPE_137M, LLM_TYPE_140M, + LLM_TYPE_149M, LLM_TYPE_160M, LLM_TYPE_190M, LLM_TYPE_220M, @@ -39,6 +41,7 @@ enum llm_type { LLM_TYPE_335M, LLM_TYPE_350M, LLM_TYPE_360M, + LLM_TYPE_395M, LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 7b01a2edfe..cd4092ca07 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1878,7 +1878,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || - tokenizer_pre == "mellum") { + tokenizer_pre == "mellum" || + tokenizer_pre == "modern-bert") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "jina-v1-en" || @@ -2528,6 +2529,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) { _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); } + } else if (_contains_any(model_name, {"modern-bert"})) { + if (token_to_id.count("[MASK]") == 0) { + LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__); + } else { + _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true); + } } } } diff --git a/src/models/models.h b/src/models/models.h index ffb36acc61..53a5810659 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -327,6 +327,11 @@ struct llm_build_mistral3 : public llm_graph_context { llm_build_mistral3(const llama_model & model, const llm_graph_params & params); }; +template <bool iswa> +struct llm_build_modern_bert : public llm_graph_context { + llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp new file mode 100644 index 0000000000..c7809bdedf --- /dev/null +++ b/src/models/modern-bert.cpp @@ -0,0 +1,126 @@ +#include "models.h" + +template <bool iswa> +llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, const llm_graph_params
& params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = build_inp_pos(); + + // construct the input token embeddings (positions are applied later via RoPE) + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + // embedding layer norm + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + auto * inp_attn = build_attn_inp_no_cache(); + + for (int il = 0; il < n_layer; ++il) { + float freq_base_l = 0.0f; + + if constexpr (iswa) { + freq_base_l = model.get_rope_freq_base(cparams, il); + } else { + freq_base_l = freq_base; + } + + cur = inpL; + + // attention layer norm + if (model.layers[il].attn_norm) { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + } + + // self attention + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + const size_t type_size = ggml_type_size(cur->type); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa)); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // re-add the layer input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FFN layer norm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + + // residual connection around the FFN + cur = ggml_add(ctx0, cur, ffn_inp); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM, -1); + cb(cur, "final_norm_out", -1); + + if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + // extract the CLS token embedding (first row) + cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0); + cb(cur, "cls_pooled_embd", -1); + } + + cb(cur, "res_embd", -1); + res->t_embd = cur; + ggml_build_forward_expand(gf, cur); +} + +// Explicit template instantiations +template struct llm_build_modern_bert<false>; +template struct llm_build_modern_bert<true>;
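For orientation: ModernBERT-style encoders alternate sliding-window and global attention, and the local layers use their own RoPE base (the `rope.freq_base_swa` key added above), which is why the builder picks `freq_base_l` per layer when `iswa` is set. A standalone sketch of the published ModernBERT layer layout; the period and theta values below are the model's config defaults, assumed here for illustration rather than read from this patch:

```cpp
#include <cstdio>

// Illustrative only: layer layout for a ModernBERT-style pattern of period 3.
// The constants are ModernBERT config defaults (assumed, not read from GGUF).
int main() {
    const int   n_layer       = 22;        // modern-bert-base
    const int   swa_period    = 3;         // global_attn_every_n_layers
    const float freq_base     = 160000.0f; // global_rope_theta
    const float freq_base_swa = 10000.0f;  // local_rope_theta

    for (int il = 0; il < n_layer; ++il) {
        const bool is_global = il % swa_period == 0; // layers 0, 3, 6, ... use full attention
        std::printf("layer %2d: %-14s rope theta = %.0f\n", il,
                    is_global ? "global" : "sliding-window",
                    is_global ? freq_base : freq_base_swa);
    }
    return 0;
}
```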
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f4801ee3da..a9424708b6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2329,11 +2329,13 @@ struct test_set_rows : public test_case { struct test_rope_set_rows : public test_case { const ggml_type type; const ggml_type type_idx; - const std::array<int64_t, 4> ne; + const std::array<int64_t, 4> ne_a; int mode; + const int n_ctx{512}; + const int n_dims{128}; std::string vars() override { - return VARS_TO_STR4(type, type_idx, ne, mode); + return VARS_TO_STR4(type, type_idx, ne_a, mode); } std::string op_desc(ggml_tensor * t) override { @@ -2345,24 +2347,51 @@ struct test_rope_set_rows : public test_case { test_rope_set_rows(ggml_type type, ggml_type type_idx, - std::array<int64_t, 4> ne, + std::array<int64_t, 4> ne_a, int mode) - : type(type), type_idx(type_idx), ne(ne), mode(mode) {} + : type(type), type_idx(type_idx), ne_a(ne_a), mode(mode) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], ne[1], ne[2], 1); - ggml_set_name(src, "src"); + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne_a[0], ne_a[1], ne_a[2], 1); + ggml_set_name(a, "a"); - ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]); + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; - ggml_tensor * rope = ggml_rope(ctx, src, pos, ne[0], mode); + ggml_tensor * pos; + if (is_mrope || is_vision) { + pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4); + } else { + pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]); + } + ggml_set_name(pos, "pos"); - ggml_tensor * view = ggml_view_2d(ctx, rope, ne[0] * ne[1], ne[2], rope->nb[2], 0); + float fs = 1.4245f; + float ef = 0.7465f; + float af = 1.4245f; + ggml_tensor * freq = nullptr; - ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0] * ne[1], ne[2] * ne[3], 1, 1); + ggml_tensor * rope = nullptr; + if (is_mrope) { + if (is_vision) { + GGML_ASSERT(n_dims/4 > 0); + int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only uses the first two dimensions for the image (x, y) coordinates + rope = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } else { + GGML_ASSERT(n_dims/3 > 0); + int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0}; + rope = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } + } else { + rope = ggml_rope(ctx, a, pos, ne_a[0], mode); + } + + ggml_tensor * view = ggml_view_2d(ctx, rope, ne_a[0] * ne_a[1], ne_a[2], rope->nb[2], 0); + + ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne_a[0] * ne_a[1], ne_a[2] * ne_a[3], 1, 1); ggml_set_name(dst, "dst"); - ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne[2], 1, 1); + ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, ne_a[2], 1, 1); ggml_set_name(row_idxs, "row_idxs"); ggml_tensor * out = ggml_set_rows(ctx, dst, view, row_idxs); @@ -2373,14 +2402,26 @@ struct test_rope_set_rows : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) { + if (strcmp(t->name, "row_idxs") == 0) { if (ggml_is_view_op(t->op)) { continue; } - - init_set_rows_row_ids(t, ne[2]); + init_set_rows_row_ids(t, ne_a[2]); + } else if (t->type == GGML_TYPE_I32) { + // pos + const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ?
ne_a[2] * 4 : ne_a[2]; + std::vector<int> data(num_pos_ids); + for (int i = 0; i < num_pos_ids; i++) { + data[i] = rand() % n_ctx; + } + ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int)); } else { - init_tensor_uniform(t); + if (t->ne[0] == n_dims/2) { + // frequency factors in the range [0.9f, 1.1f] + init_tensor_uniform(t, 0.9f, 1.1f); + } else { + init_tensor_uniform(t); + } } } } @@ -6854,10 +6895,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { } } - for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX }) { + for (int mode : { GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION }) { for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { - test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 1, 100 }, mode)); - test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, 512, 1 }, mode)); + for (int ne2 : {1, 8, 512}) { + test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, ne2, 1 }, mode)); + test_cases.emplace_back(new test_rope_set_rows(type, GGML_TYPE_I64, { 128, 32, ne2, 3 }, mode)); + } } } diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp index 566b039a07..34746c200c 100644 --- a/tests/test-grammar-llguidance.cpp +++ b/tests/test-grammar-llguidance.cpp @@ -1196,6 +1196,9 @@ int main(int argc, const char ** argv) { test_sampler_chain(); + llama_free(ctx); + llama_model_free(model); + fprintf(stdout, "All tests passed.\n"); return 0; } diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 59dda48772..37f8312c46 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -300,8 +300,8 @@ int main(int argc, char **argv) { fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str()); } - llama_model_free(model); llama_free(ctx); + llama_model_free(model); llama_backend_free(); diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp index b183da47f3..505dbfdb93 100644 --- a/tests/test-tokenizer-1-bpe.cpp +++ b/tests/test-tokenizer-1-bpe.cpp @@ -146,8 +146,8 @@ int main(int argc, char **argv) { } } - llama_model_free(model); llama_free(ctx); + llama_model_free(model); llama_backend_free(); diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp index ba6e94ba8e..8e370d2c7b 100644 --- a/tests/test-tokenizer-1-spm.cpp +++ b/tests/test-tokenizer-1-spm.cpp @@ -116,8 +116,8 @@ int main(int argc, char ** argv) { } } - llama_model_free(model); llama_free(ctx); + llama_model_free(model); llama_backend_free(); diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index 2032a386bb..0f627c5ff6 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -55,6 +55,7 @@ int main(int argc, char ** argv) { if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + llama_model_free(model); return 1; } @@ -108,6 +109,8 @@ int main(int argc, char ** argv) { if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) { LOG_ERR("%s: llama_decode() failed\n", __func__); + llama_free(ctx); + llama_model_free(model); return 1; } } @@ -147,6 +150,8 @@ int main(int argc, char ** argv) { if (!decode_helper(ctx, batch, ctx_params.n_batch, false)) { LOG_ERR("%s: llama_decode() failed\n", __func__); + llama_free(ctx); + llama_model_free(model); return 1; }
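Aside: the batched-bench changes here repeat `llama_free(ctx); llama_model_free(model);` on every early-return path. For comparison, a minimal sketch of the same cleanup using the RAII wrappers from `llama-cpp.h` (a hypothetical refactor, not something this patch does); declaring the model before the context guarantees the context is destroyed first on every path:

```cpp
#include "llama-cpp.h"

// Hypothetical refactor (not part of this patch): RAII versions of the same
// early-return cleanup. Destructors run bottom-to-top, so every return path
// frees the context first and the model second.
static int run_bench(const char * model_path,
                     llama_model_params   mparams,
                     llama_context_params cparams) {
    llama_model_ptr model(llama_model_load_from_file(model_path, mparams));
    if (!model) {
        return 1;
    }

    llama_context_ptr ctx(llama_init_from_model(model.get(), cparams));
    if (!ctx) {
        return 1; // model is released automatically
    }

    // ... prompt-processing and text-generation loops; any early
    // `return 1` here frees ctx, then model
    return 0;
}
```

@@ -165,6 +170,8 @@ int main(int argc, char ** argv) {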
common_batch_add(batch, get_token_rand(), pp + 0, { 0 }, true); if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) { LOG_ERR("%s: llama_decode() failed\n", __func__); + llama_free(ctx); + llama_model_free(model); return 1; } llama_memory_seq_rm(mem, 0, pp, -1); @@ -184,6 +191,8 @@ int main(int argc, char ** argv) { if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) { LOG_ERR("%s: llama_decode() failed\n", __func__); + llama_free(ctx); + llama_model_free(model); return 1; } } @@ -200,6 +209,8 @@ int main(int argc, char ** argv) { if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) { LOG_ERR("%s: llama_decode() failed\n", __func__); + llama_free(ctx); + llama_model_free(model); return 1; } } diff --git a/tools/cli/README.md b/tools/cli/README.md index 1333ed77b7..7b8b8692e9 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -1 +1,187 @@ -TODO +# llama.cpp/tools/cli + +## Usage + + + + + +### Common params + +| Argument | Explanation | +| -------- | ----------- | +| `-h, --help, --usage` | print usage and exit | +| `--version` | show version and build info | +| `-cl, --cache-list` | show list of models in cache | +| `--completion-bash` | print source-able bash completion script for llama.cpp | +| `--verbose-prompt` | print a verbose prompt before generation (default: false) | +| `-t, --threads N` | number of CPU threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | +| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | +| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | +| `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask | +| `--cpu-strict <0\|1>` | use strict CPU placement (default: 0) | +| `--prio N` | set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: 0) | +| `--poll <0...100>` | use polling level to wait for work (0 - no polling, default: 50) | +| `-Cb, --cpu-mask-batch M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask) | +| `-Crb, --cpu-range-batch lo-hi` | ranges of CPUs for affinity. Complements --cpu-mask-batch | +| `--cpu-strict-batch <0\|1>` | use strict CPU placement (default: same as --cpu-strict) | +| `--prio-batch N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0) | +| `--poll-batch <0\|1>` | use polling to wait for work (default: same as --poll) | +| `-c, --ctx-size N` | size of the prompt context (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE) | +| `-n, --predict, --n-predict N` | number of tokens to predict (default: -1, -1 = infinity)
(env: LLAMA_ARG_N_PREDICT) | +| `-b, --batch-size N` | logical maximum batch size (default: 2048)
(env: LLAMA_ARG_BATCH) | +| `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | +| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | +| `--swa-full` | use full-size SWA cache (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
(env: LLAMA_ARG_SWA_FULL) | +| `-fa, --flash-attn [on\|off\|auto]` | set Flash Attention use ('on', 'off', or 'auto', default: 'auto')
(env: LLAMA_ARG_FLASH_ATTN) | +| `-p, --prompt PROMPT` | prompt to start generation with; for system message, use -sys | +| `--perf, --no-perf` | whether to enable internal libllama performance timings (default: false)
(env: LLAMA_ARG_PERF) | +| `-f, --file FNAME` | a file containing the prompt (default: none) | +| `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) | +| `-e, --escape, --no-escape` | whether to process escape sequences (\n, \r, \t, \', \", \\) (default: true) | +| `--rope-scaling {none,linear,yarn}` | RoPE frequency scaling method, defaults to linear unless specified by the model<br/>
(env: LLAMA_ARG_ROPE_SCALING_TYPE) | +| `--rope-scale N` | RoPE context scaling factor, expands context by a factor of N
(env: LLAMA_ARG_ROPE_SCALE) | +| `--rope-freq-base N` | RoPE base frequency, used by NTK-aware scaling (default: loaded from model)
(env: LLAMA_ARG_ROPE_FREQ_BASE) | +| `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N
(env: LLAMA_ARG_ROPE_FREQ_SCALE) | +| `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)
(env: LLAMA_ARG_YARN_ORIG_CTX) | +| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | +| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | +| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | +| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | +| `-kvo, --kv-offload, -nkvo, --no-kv-offload` | whether to enable KV cache offloading (default: enabled)
(env: LLAMA_ARG_KV_OFFLOAD) | +| `--repack, -nr, --no-repack` | whether to enable weight repacking (default: enabled)
(env: LLAMA_ARG_REPACK) | +| `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | +| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | +| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | +| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | +| `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | +| `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | +| `--mmap, --no-mmap` | whether to memory-map model (if disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | +| `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggml-org/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | +| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>
use --list-devices to see a list of available devices
(env: LLAMA_ARG_DEVICE) | +| `--list-devices` | print list of available devices and exit | +| `-ot, --override-tensor <tensor name pattern>=<buffer type>,...` | override tensor buffer type | +| `-cmoe, --cpu-moe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>
(env: LLAMA_ARG_CPU_MOE) | +| `-ncmoe, --n-cpu-moe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU
(env: LLAMA_ARG_N_CPU_MOE) | +| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)
(env: LLAMA_ARG_N_GPU_LAYERS) | +| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
(env: LLAMA_ARG_SPLIT_MODE) | +| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1
(env: LLAMA_ARG_TENSOR_SPLIT) | +| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)
(env: LLAMA_ARG_MAIN_GPU) | +| `-fit, --fit [on\|off]` | whether to adjust unset arguments to fit in device memory ('on' or 'off', default: 'on')
(env: LLAMA_ARG_FIT) | +| `-fitt, --fit-target MiB` | target margin per device for --fit option, default: 1024
(env: LLAMA_ARG_FIT_TARGET) | +| `-fitc, --fit-ctx N` | minimum ctx size that can be set by --fit option, default: 4096
(env: LLAMA_ARG_FIT_CTX) | +| `--check-tensors` | check model tensor data for invalid values (default: false) | +| `--override-kv KEY=TYPE:VALUE,...` | advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.
types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false | +| `--op-offload, --no-op-offload` | whether to offload host tensor operations to device (default: true) | +| `--lora FNAME` | path to LoRA adapter (use comma-separated values to load multiple adapters) | +| `--lora-scaled FNAME:SCALE,...` | path to LoRA adapter with user-defined scaling (format: FNAME:SCALE,...)<br/>
note: use comma-separated values | +| `--control-vector FNAME` | add a control vector
note: use comma-separated values to add multiple control vectors | +| `--control-vector-scaled FNAME:SCALE,...` | add a control vector with user-defined scaling SCALE<br/>
note: use comma-separated values (format: FNAME:SCALE,...) | +| `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive | +| `-m, --model FNAME` | model path to load
(env: LLAMA_ARG_MODEL) | +| `-mu, --model-url MODEL_URL` | model download url (default: unused)
(env: LLAMA_ARG_MODEL_URL) | +| `-dr, --docker-repo [<repo>/]<model>[:quant]` | Docker Hub model repository. repo is optional, defaults to ai/. quant is optional, defaults to :latest.<br/>
example: gemma3
(default: unused)
(env: LLAMA_ARG_DOCKER_REPO) | +| `-hf, -hfr, --hf-repo <user>/<model>[:quant]` | Hugging Face model repository; quant is optional, case-insensitive, defaults to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.<br/>
mmproj is also downloaded automatically if available. to disable, add --no-mmproj
example: unsloth/phi-4-GGUF:q4_k_m
(default: unused)
(env: LLAMA_ARG_HF_REPO) | +| `-hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)<br/>
(env: LLAMA_ARG_HFD_REPO) | +| `-hff, --hf-file FILE` | Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)
(env: LLAMA_ARG_HF_FILE) | +| `-hfv, -hfrv, --hf-repo-v <user>/<model>[:quant]` | Hugging Face model repository for the vocoder model (default: unused)<br/>
(env: LLAMA_ARG_HF_REPO_V) | +| `-hffv, --hf-file-v FILE` | Hugging Face model file for the vocoder model (default: unused)
(env: LLAMA_ARG_HF_FILE_V) | +| `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | +| `--log-disable` | Disable logging | +| `--log-file FNAME` | Log to file<br/>
(env: LLAMA_LOG_FILE) | +| `--log-colors [on\|off\|auto]` | Set colored logging ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal
(env: LLAMA_LOG_COLORS) | +| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | +| `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | +| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | +| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | +| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | +| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | +| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) | + + +### Sampling params + +| Argument | Explanation | +| -------- | ----------- | +| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'
(default: penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature) | +| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | +| `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | +| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | +| `--temp N` | temperature (default: 0.8) | +| `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | +| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | +| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | +| `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) | +| `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) | +| `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) | +| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) | +| `--repeat-last-n N` | last n tokens to consider for penalization (default: 64, 0 = disabled, -1 = ctx_size) | +| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) | +| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) | +| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) | +| `--dry-multiplier N` | set DRY sampling multiplier (default: 0.0, 0.0 = disabled) | +| `--dry-base N` | set DRY sampling base value (default: 1.75) | +| `--dry-allowed-length N` | set allowed length for DRY sampling (default: 2) | +| `--dry-penalty-last-n N` | set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) | +| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers ('\n', ':', '"', '*') in the process; use "none" to not use any sequence breakers | +| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) | +| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) | +| `--mirostat N` | use Mirostat sampling.<br/>
Top K, Nucleus and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | +| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.1) | +| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.0) | +| `-l, --logit-bias TOKEN_ID(+/-)BIAS` | modifies the likelihood of a token appearing in the completion,<br/>
i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' | +| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | +| `--grammar-file FNAME` | file to read grammar from | +| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | +| `-jf, --json-schema-file FILE` | File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | + + +### CLI-specific params + +| Argument | Explanation | +| -------- | ----------- | +| `--display-prompt, --no-display-prompt` | whether to print prompt at generation (default: true) | +| `-co, --color [on\|off\|auto]` | Colorize output to distinguish prompt and user input from generations ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal | +| `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_CTX_CHECKPOINTS) | +| `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | +| `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | +| `-sys, --system-prompt PROMPT` | system prompt to use with model (if applicable, depending on chat template) | +| `--show-timings, --no-show-timings` | whether to show timing information after each response (default: true)
(env: LLAMA_ARG_SHOW_TIMINGS) | +| `-sysf, --system-prompt-file FNAME` | a file containing the system prompt (default: none) | +| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode | +| `-sp, --special` | special tokens output enabled (default: false) | +| `-cnv, --conversation, -no-cnv, --no-conversation` | whether to run in conversation mode:
- does not print special tokens and suffix/prefix
- interactive mode is also enabled
(default: auto enabled if chat template is available) | +| `-st, --single-turn` | run conversation for a single turn only, then exit when done
will not be interactive if first turn is predefined with --prompt
(default: false) | +| `-mli, --multiline-input` | allows you to write or paste multiple lines without ending each in '\' | +| `--warmup, --no-warmup` | whether to perform warmup with an empty run (default: enabled) | +| `-mm, --mmproj FILE` | path to a multimodal projector file. see tools/mtmd/README.md
note: if -hf is used, this argument can be omitted
(env: LLAMA_ARG_MMPROJ) | +| `-mmu, --mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md
(env: LLAMA_ARG_MMPROJ_URL) | +| `--mmproj-auto, --no-mmproj, --no-mmproj-auto` | whether to use multimodal projector file (if available), useful when using -hf (default: enabled)
(env: LLAMA_ARG_MMPROJ_AUTO) | +| `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)
(env: LLAMA_ARG_MMPROJ_OFFLOAD) | +| `--image, --audio FILE` | path to an image or audio file. use with multimodal models, use comma-separated values for multiple files | +| `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | +| `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | +| `-otd, --override-tensor-draft <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model | +| `-cmoed, --cpu-moe-draft` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>
(env: LLAMA_ARG_CPU_MOE_DRAFT) | +| `-ncmoed, --n-cpu-moe-draft N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model
(env: LLAMA_ARG_N_CPU_MOE_DRAFT) | +| `--chat-template-kwargs STRING` | sets additional params for the json template parser
(env: LLAMA_CHAT_TEMPLATE_KWARGS) | +| `--jinja, --no-jinja` | whether to use jinja template engine for chat (default: enabled)
(env: LLAMA_ARG_JINJA) | +| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>
(default: auto)
(env: LLAMA_ARG_THINK) | +| `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | +| `--simple-io` | use basic IO for better compatibility in subprocesses and limited consoles | +| `--draft, --draft-n, --draft-max N` | number of tokens to draft for speculative decoding (default: 16)
(env: LLAMA_ARG_DRAFT_MAX) | +| `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)
(env: LLAMA_ARG_DRAFT_MIN) | +| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.8)
(env: LLAMA_ARG_DRAFT_P_MIN) | +| `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE_DRAFT) | +| `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>
use --list-devices to see a list of available devices | +| `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | +| `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_MODEL_DRAFT) | +| `--spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | +| `--gpt-oss-20b-default` | use gpt-oss-20b (note: can download weights from the internet) | +| `--gpt-oss-120b-default` | use gpt-oss-120b (note: can download weights from the internet) | +| `--vision-gemma-4b-default` | use Gemma 3 4B QAT (note: can download weights from the internet) | +| `--vision-gemma-12b-default` | use Gemma 3 12B QAT (note: can download weights from the internet) | + + diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 128679d020..2f0ffea1c2 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -216,7 +216,7 @@ int main(int argc, char ** argv) { ctx_cli.ctx_server.start_loop(); }); - auto inf = ctx_cli.ctx_server.get_info(); + auto inf = ctx_cli.ctx_server.get_meta(); std::string modalities = "text"; if (inf.has_inp_image) { modalities += ", vision"; diff --git a/tools/completion/README.md b/tools/completion/README.md index 57ef394213..391488579e 100644 --- a/tools/completion/README.md +++ b/tools/completion/README.md @@ -5,13 +5,14 @@ This example program allows you to use various LLaMA language models easily and ## Table of Contents 1. [Quick Start](#quick-start) -2. [Common Options](#common-options) -3. [Input Prompts](#input-prompts) -4. [Interaction](#interaction) -5. [Context Management](#context-management) -6. [Generation Flags](#generation-flags) -7. [Performance Tuning and Memory Options](#performance-tuning-and-memory-options) -8. [Additional Options](#additional-options) +2. [Usage](#usage) +3. [Common Options](#common-options) +4. [Input Prompts](#input-prompts) +5. [Interaction](#interaction) +6. [Context Management](#context-management) +7. [Generation Flags](#generation-flags) +8. [Performance Tuning and Memory Options](#performance-tuning-and-memory-options) +9. [Additional Options](#additional-options) ## Quick Start @@ -82,6 +83,177 @@ Once downloaded, place your model in the models folder in llama.cpp. llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 ``` +## Usage + + + + + +### Common params + +| Argument | Explanation | +| -------- | ----------- | +| `-h, --help, --usage` | print usage and exit | +| `--version` | show version and build info | +| `-cl, --cache-list` | show list of models in cache | +| `--completion-bash` | print source-able bash completion script for llama.cpp | +| `--verbose-prompt` | print a verbose prompt before generation (default: false) | +| `-t, --threads N` | number of CPU threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | +| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | +| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | +| `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask | +| `--cpu-strict <0\|1>` | use strict CPU placement (default: 0) | +| `--prio N` | set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: 0) | +| `--poll <0...100>` | use polling level to wait for work (0 - no polling, default: 50) | +| `-Cb, --cpu-mask-batch M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask) | +| `-Crb, --cpu-range-batch lo-hi` | ranges of CPUs for affinity. Complements --cpu-mask-batch | +| `--cpu-strict-batch <0\|1>` | use strict CPU placement (default: same as --cpu-strict) | +| `--prio-batch N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0) | +| `--poll-batch <0\|1>` | use polling to wait for work (default: same as --poll) | +| `-c, --ctx-size N` | size of the prompt context (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE) | +| `-n, --predict, --n-predict N` | number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled)
(env: LLAMA_ARG_N_PREDICT) | +| `-b, --batch-size N` | logical maximum batch size (default: 2048)
(env: LLAMA_ARG_BATCH) | +| `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | +| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | +| `--swa-full` | use full-size SWA cache (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
(env: LLAMA_ARG_SWA_FULL) | +| `-fa, --flash-attn [on\|off\|auto]` | set Flash Attention use ('on', 'off', or 'auto', default: 'auto')
(env: LLAMA_ARG_FLASH_ATTN) | +| `-p, --prompt PROMPT` | prompt to start generation with; for system message, use -sys | +| `--perf, --no-perf` | whether to enable internal libllama performance timings (default: false)
(env: LLAMA_ARG_PERF) | +| `-f, --file FNAME` | a file containing the prompt (default: none) | +| `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) | +| `-e, --escape, --no-escape` | whether to process escape sequences (\n, \r, \t, \', \", \\) (default: true) | +| `--rope-scaling {none,linear,yarn}` | RoPE frequency scaling method, defaults to linear unless specified by the model<br/>
(env: LLAMA_ARG_ROPE_SCALING_TYPE) | +| `--rope-scale N` | RoPE context scaling factor, expands context by a factor of N
(env: LLAMA_ARG_ROPE_SCALE) | +| `--rope-freq-base N` | RoPE base frequency, used by NTK-aware scaling (default: loaded from model)
(env: LLAMA_ARG_ROPE_FREQ_BASE) | +| `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N
(env: LLAMA_ARG_ROPE_FREQ_SCALE) | +| `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)
(env: LLAMA_ARG_YARN_ORIG_CTX) | +| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | +| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | +| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | +| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | +| `-kvo, --kv-offload, -nkvo, --no-kv-offload` | whether to enable KV cache offloading (default: enabled)
(env: LLAMA_ARG_KV_OFFLOAD) | +| `--repack, -nr, --no-repack` | whether to enable weight repacking (default: enabled)
(env: LLAMA_ARG_REPACK) | +| `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | +| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | +| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | +| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | +| `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | +| `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | +| `--mmap, --no-mmap` | whether to memory-map model (if disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | +| `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggml-org/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | +| `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>
use --list-devices to see a list of available devices
(env: LLAMA_ARG_DEVICE) | +| `--list-devices` | print list of available devices and exit | +| `-ot, --override-tensor <tensor name pattern>=<buffer type>,...` | override tensor buffer type | +| `-cmoe, --cpu-moe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>
(env: LLAMA_ARG_CPU_MOE) | +| `-ncmoe, --n-cpu-moe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU
(env: LLAMA_ARG_N_CPU_MOE) | +| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)
(env: LLAMA_ARG_N_GPU_LAYERS) | +| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
(env: LLAMA_ARG_SPLIT_MODE) | +| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1
(env: LLAMA_ARG_TENSOR_SPLIT) | +| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)
(env: LLAMA_ARG_MAIN_GPU) | +| `-fit, --fit [on\|off]` | whether to adjust unset arguments to fit in device memory ('on' or 'off', default: 'on')
(env: LLAMA_ARG_FIT) | +| `-fitt, --fit-target MiB` | target margin per device for --fit option, default: 1024
(env: LLAMA_ARG_FIT_TARGET) | +| `-fitc, --fit-ctx N` | minimum ctx size that can be set by --fit option, default: 4096
(env: LLAMA_ARG_FIT_CTX) | +| `--check-tensors` | check model tensor data for invalid values (default: false) | +| `--override-kv KEY=TYPE:VALUE,...` | advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.
types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false | +| `--op-offload, --no-op-offload` | whether to offload host tensor operations to device (default: true) | +| `--lora FNAME` | path to LoRA adapter (use comma-separated values to load multiple adapters) | +| `--lora-scaled FNAME:SCALE,...` | path to LoRA adapter with user-defined scaling (format: FNAME:SCALE,...)<br/>
note: use comma-separated values | +| `--control-vector FNAME` | add a control vector
note: use comma-separated values to add multiple control vectors | +| `--control-vector-scaled FNAME:SCALE,...` | add a control vector with user-defined scaling SCALE<br/>
note: use comma-separated values (format: FNAME:SCALE,...) | +| `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive | +| `-m, --model FNAME` | model path to load
(env: LLAMA_ARG_MODEL) | +| `-mu, --model-url MODEL_URL` | model download url (default: unused)
(env: LLAMA_ARG_MODEL_URL) | +| `-dr, --docker-repo [<repo>/]<model>[:quant]` | Docker Hub model repository. repo is optional, defaults to ai/. quant is optional, defaults to :latest.<br/>
example: gemma3
(default: unused)
(env: LLAMA_ARG_DOCKER_REPO) | +| `-hf, -hfr, --hf-repo <user>/<model>[:quant]` | Hugging Face model repository; quant is optional, case-insensitive, defaults to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.<br/>
mmproj is also downloaded automatically if available. to disable, add --no-mmproj
example: unsloth/phi-4-GGUF:q4_k_m
(default: unused)
(env: LLAMA_ARG_HF_REPO) | +| `-hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_HFD_REPO) | +| `-hff, --hf-file FILE` | Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)
(env: LLAMA_ARG_HF_FILE) | +| `-hfv, -hfrv, --hf-repo-v <user>/<model>[:quant]` | Hugging Face model repository for the vocoder model (default: unused)
(env: LLAMA_ARG_HF_REPO_V) | +| `-hffv, --hf-file-v FILE` | Hugging Face model file for the vocoder model (default: unused)
(env: LLAMA_ARG_HF_FILE_V) | +| `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | +| `--log-disable` | Disable logging | +| `--log-file FNAME` | Log to file
(env: LLAMA_LOG_FILE) | +| `--log-colors [on\|off\|auto]` | Set colored logging ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal
(env: LLAMA_LOG_COLORS) | +| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | +| `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | +| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
(default: 3)

(env: LLAMA_LOG_VERBOSITY) | +| `--log-prefix` | Enable prefix in log messages
(env: LLAMA_LOG_PREFIX) | +| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | +| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | +| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) | + + +### Sampling params + +| Argument | Explanation | +| -------- | ----------- | +| `--samplers SAMPLERS` | samplers that will be used for generation in the given order, separated by ';'
(default: penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature) | +| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | +| `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | +| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | +| `--temp N` | temperature (default: 0.8) | +| `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | +| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | +| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | +| `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) | +| `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) | +| `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) | +| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) | +| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | +| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) | +| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) | +| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) | +| `--dry-multiplier N` | set DRY sampling multiplier (default: 0.0, 0.0 = disabled) | +| `--dry-base N` | set DRY sampling base value (default: 1.75) | +| `--dry-allowed-length N` | set allowed length for DRY sampling (default: 2) | +| `--dry-penalty-last-n N` | set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) | +| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers ('\n', ':', '"', '*') in the process; use "none" to not use any sequence breakers | +| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) | +| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) | +| `--mirostat N` | use Mirostat sampling.
Top K, Nucleus and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | +| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.1) | +| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.0) | +| `-l, --logit-bias TOKEN_ID(+/-)BIAS` | modifies the likelihood of token appearing in the completion,
i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' | +| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | +| `--grammar-file FNAME` | file to read grammar from | +| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | +| `-jf, --json-schema-file FILE` | File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | + + +### Completion-specific params + +| Argument | Explanation | +| -------- | ----------- | +| `--display-prompt, --no-display-prompt` | whether to print prompt at generation (default: true) | +| `-co, --color [on\|off\|auto]` | Colorize output to distinguish prompt and user input from generations ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal | +| `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | +| `-sys, --system-prompt PROMPT` | system prompt to use with model (if applicable, depending on chat template) | +| `-sysf, --system-prompt-file FNAME` | a file containing the system prompt (default: none) | +| `-ptc, --print-token-count N` | print token count every N tokens (default: -1) | +| `--prompt-cache FNAME` | file to cache prompt state for faster startup (default: none) | +| `--prompt-cache-all` | if specified, saves user input and generations to cache as well | +| `--prompt-cache-ro` | if specified, uses the prompt cache but does not update it | +| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode | +| `-sp, --special` | special tokens output enabled (default: false) | +| `-cnv, --conversation, -no-cnv, --no-conversation` | whether to run in conversation mode:
- does not print special tokens and suffix/prefix
- interactive mode is also enabled
(default: auto enabled if chat template is available) | +| `-st, --single-turn` | run conversation for a single turn only, then exit when done
will not be interactive if first turn is predefined with --prompt
(default: false) | +| `-i, --interactive` | run in interactive mode (default: false) | +| `-if, --interactive-first` | run in interactive mode and wait for input right away (default: false) | +| `-mli, --multiline-input` | allows you to write or paste multiple lines without ending each in '\' | +| `--in-prefix-bos` | prefix BOS to user inputs, preceding the `--in-prefix` string | +| `--in-prefix STRING` | string to prefix user inputs with (default: empty) | +| `--in-suffix STRING` | string to suffix after user inputs with (default: empty) | +| `--warmup, --no-warmup` | whether to perform warmup with an empty run (default: enabled) | +| `-gan, --grp-attn-n N` | group-attention factor (default: 1)
(env: LLAMA_ARG_GRP_ATTN_N) | +| `-gaw, --grp-attn-w N` | group-attention width (default: 512)
(env: LLAMA_ARG_GRP_ATTN_W) | +| `--jinja, --no-jinja` | whether to use jinja template engine for chat (default: disabled)
(env: LLAMA_ARG_JINJA) | +| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`
(default: auto)
(env: LLAMA_ARG_THINK) | +| `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | +| `--simple-io` | use basic IO for better compatibility in subprocesses and limited consoles | + + + ## Common Options In this section, we cover the most commonly used options for running the `llama-completion` program with the LLaMA models: diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 0be6ed6948..b431c7f31b 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -2102,6 +2102,8 @@ int main(int argc, char ** argv) { struct ggml_threadpool_params tpp = ggml_threadpool_params_default(t.n_threads); if (!parse_cpu_mask(t.cpu_mask, tpp.cpumask)) { fprintf(stderr, "%s: failed to parse cpu-mask: %s\n", __func__, t.cpu_mask.c_str()); + llama_free(ctx); + llama_model_free(lmodel); exit(1); } tpp.strict_cpu = t.cpu_strict; @@ -2111,6 +2113,8 @@ int main(int argc, char ** argv) { struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp); if (!threadpool) { fprintf(stderr, "%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); + llama_free(ctx); + llama_model_free(lmodel); exit(1); } @@ -2126,6 +2130,8 @@ int main(int argc, char ** argv) { bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); if (!res) { fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__); + llama_free(ctx); + llama_model_free(lmodel); exit(1); } } @@ -2136,6 +2142,8 @@ int main(int argc, char ** argv) { bool res = test_gen(ctx, 1, t.n_threads); if (!res) { fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__); + llama_free(ctx); + llama_model_free(lmodel); exit(1); } } @@ -2164,6 +2172,8 @@ int main(int argc, char ** argv) { bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads); if (!res) { fprintf(stderr, "%s: error: failed to run depth\n", __func__); + llama_free(ctx); + llama_model_free(lmodel); exit(1); } @@ -2189,6 +2199,8 @@ int main(int argc, char ** argv) { bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); if (!res) { fprintf(stderr, "%s: error: failed to run prompt\n", __func__); + llama_free(ctx); + llama_model_free(lmodel); exit(1); } } @@ -2200,6 +2212,8 @@ int main(int argc, char ** argv) { bool res = test_gen(ctx, t.n_gen, t.n_threads); if (!res) { fprintf(stderr, "%s: error: failed to run gen\n", __func__); + llama_free(ctx); + llama_model_free(lmodel); exit(1); } } diff --git a/tools/server/README.md b/tools/server/README.md index 71f1d4777c..1ae5eae4c6 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -23,9 +23,11 @@ For the ful list of features, please refer to [server's changelog](https://githu ## Usage - + -**Common params** + + +### Common params | Argument | Explanation | | -------- | ----------- | @@ -38,13 +40,13 @@ For the ful list of features, please refer to [server's changelog](https://githu | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | | `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask | -| `--cpu-strict <0\|1>` | use strict CPU placement (default: 0)
| -| `--prio N` | set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: 0)
| -| `--poll <0...100>` | use polling level to wait for work (0 - no polling, default: 50)
| +| `--cpu-strict <0\|1>` | use strict CPU placement (default: 0) | +| `--prio N` | set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: 0) | +| `--poll <0...100>` | use polling level to wait for work (0 - no polling, default: 50) | | `-Cb, --cpu-mask-batch M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask) | | `-Crb, --cpu-range-batch lo-hi` | ranges of CPUs for affinity. Complements --cpu-mask-batch | | `--cpu-strict-batch <0\|1>` | use strict CPU placement (default: same as --cpu-strict) | -| `--prio-batch N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0)
| +| `--prio-batch N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0) | | `--poll-batch <0\|1>` | use polling to wait for work (default: same as --poll) | | `-c, --ctx-size N` | size of the prompt context (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE) | | `-n, --predict, --n-predict N` | number of tokens to predict (default: -1, -1 = infinity)
(env: LLAMA_ARG_N_PREDICT) | @@ -114,7 +116,7 @@ For the ful list of features, please refer to [server's changelog](https://githu | `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) | -**Sampling params** +### Sampling params | Argument | Explanation | | -------- | ----------- | @@ -138,7 +140,7 @@ For the ful list of features, please refer to [server's changelog](https://githu | `--dry-base N` | set DRY sampling base value (default: 1.75) | | `--dry-allowed-length N` | set allowed length for DRY sampling (default: 2) | | `--dry-penalty-last-n N` | set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) | -| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers ('\n', ':', '"', '*') in the process; use "none" to not use any sequence breakers
| +| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers ('\n', ':', '"', '*') in the process; use "none" to not use any sequence breakers | | `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) | | `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) | | `--mirostat N` | use Mirostat sampling.
Top K, Nucleus and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | @@ -151,7 +153,7 @@ For the ful list of features, please refer to [server's changelog](https://githu | `-jf, --json-schema-file FILE` | File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | -**Server-specific params** +### Server-specific params | Argument | Explanation | | -------- | ----------- | @@ -159,7 +161,7 @@ For the ful list of features, please refer to [server's changelog](https://githu | `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | | `-kvu, --kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)
(env: LLAMA_ARG_KV_UNIFIED) | | `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | -| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode
| +| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode | | `-sp, --special` | special tokens output enabled (default: false) | | `--warmup, --no-warmup` | whether to perform warmup with an empty run (default: enabled) | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | @@ -208,8 +210,9 @@ For the ful list of features, please refer to [server's changelog](https://githu | `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | | `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--prefill-assistant, --no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)
when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled

(env: LLAMA_ARG_PREFILL_ASSISTANT) | -| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled)
| +| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled) | | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | +| `--sleep-idle-seconds SECONDS` | number of seconds of idleness after which the server will sleep (default: -1; -1 = disabled) | | `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | | `--draft, --draft-n, --draft-max N` | number of tokens to draft for speculative decoding (default: 16)
(env: LLAMA_ARG_DRAFT_MAX) | @@ -234,6 +237,7 @@ For the full list of features, please refer to [server's changelog](https://githu | `--vision-gemma-4b-default` | use Gemma 3 4B QAT (note: can download weights from the internet) | | `--vision-gemma-12b-default` | use Gemma 3 12B QAT (note: can download weights from the internet) | + Note: If both the command line argument and the environment variable are set for the same param, the argument will take precedence over the env var. @@ -1567,7 +1571,6 @@ Load a model Payload: - `model`: name of the model to be loaded. -- `extra_args`: (optional) an array of additional arguments to be passed to the model instance. Note: you must start the server with `--models-allow-extra-args` to enable this feature. ```json { diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index e72f2728db..97b7d67f3e 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index af21e3d45c..e4a0be44cc 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -115,26 +115,14 @@ bool lora_should_clear_cache( !lora_all_alora(next)); } -std::vector<common_adapter_lora_info> parse_lora_request( - const std::vector<common_adapter_lora_info> & lora_base, - const json & data) { - std::vector<common_adapter_lora_info> lora(lora_base); - int max_idx = lora.size(); - - // clear existing value - for (auto & entry : lora) { - entry.scale = 0.0f; - } +std::map<int, float> parse_lora_request(const json & data) { + std::map<int, float> lora; // set value for (const auto & entry : data) { int id = json_value(entry, "id", -1); float scale = json_value(entry, "scale", 0.0f); - if (0 <= id && id < max_idx) { - lora[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } + lora[id] = scale; } return lora; @@ -1440,7 +1428,7 @@ std::string safe_json_to_str(const json & data) { // TODO: reuse llama_detokenize template <typename Iter> -static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { +static std::string tokens_to_str(const llama_vocab * ctx, Iter begin, Iter end) { std::string ret; for (; begin != end; ++begin) { ret += common_token_to_piece(ctx, *begin); @@ -1450,7 +1438,12 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { } std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) { - return tokens_to_str(ctx, tokens.begin(), tokens.end()); + auto model = llama_get_model(ctx); + return tokens_to_str(llama_model_get_vocab(model), tokens.begin(), tokens.end()); +} + +std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens) { + return tokens_to_str(vocab, tokens.begin(), tokens.end()); } // format incomplete utf-8 multibyte character for output std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token); diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 0629bb5edd..152a2a3c46 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -107,9 +107,7 @@ bool lora_should_clear_cache( const std::vector<common_adapter_lora_info> & current, const std::vector<common_adapter_lora_info> & next); -std::vector<common_adapter_lora_info> parse_lora_request( - const std::vector<common_adapter_lora_info> & lora_base, - const json & data); +std::map<int, float> parse_lora_request(const json & data); bool are_lora_equal( const std::vector<common_adapter_lora_info> & l1, @@ -325,6 +323,7 @@ std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int i std::string safe_json_to_str(const json & data); std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens); +std::string tokens_to_str(const llama_vocab * vocab, const llama_tokens & tokens); // format incomplete
utf-8 multibyte character for output std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0331cb08e6..d480ff0ed2 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -507,19 +507,42 @@ struct server_metrics { // struct server_context_impl { + friend struct server_context; + +public: + // only use these pointers outside of this class: + // - when not in sleeping state + // - and, with thread-safe APIs (e.g., tokenizer calls) + llama_model * model = nullptr; + mtmd_context * mctx = nullptr; + const llama_vocab * vocab = nullptr; + + server_queue queue_tasks; + server_response queue_results; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + ~server_context_impl() { + if (!sleeping) { + // destroy() is already called when entering sleeping state + // we don't call it again here to avoid double free + destroy(); + } + } + +private: + // note: accessing these fields outside of this class is not thread-safe + // use server_context methods instead + common_params params_base; // note: keep these alive - they determine the lifetime of the model, context, etc. common_init_result_ptr llama_init; common_init_result_ptr llama_init_dft; - llama_model * model = nullptr; llama_context * ctx = nullptr; - // multimodal - mtmd_context * mctx = nullptr; - - const llama_vocab * vocab = nullptr; bool vocab_dft_compatible = true; llama_model * model_dft = nullptr; @@ -537,35 +560,19 @@ struct server_context_impl { int slots_debug = 0; - server_queue queue_tasks; - server_response queue_results; - std::unique_ptr prompt_cache; server_metrics metrics; - // cached responses for HTTP API (read-only from HTTP threads) - json json_server_props = json::object(); - json json_server_model_meta = json::object(); + json json_webui_settings = json::object(); // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; std::string model_name; // name of the loaded model, to be used by API - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; - bool sleeping = false; - ~server_context_impl() { - if (!sleeping) { - // destroy() is already called when entering sleeping state - // we don't call it again here to avoid double free - destroy(); - } - } - void destroy() { llama_init.reset(); ctx = nullptr; @@ -871,17 +878,7 @@ struct server_context_impl { metrics.init(); - if (!populate_json_responses()) { - SRV_ERR("%s", "failed to populate JSON responses\n"); - return false; - } - - return true; - } - - bool populate_json_responses() { // populate webui settings - json json_webui_settings = json::object(); { if (!params_base.webui_config_json.empty()) { try { @@ -893,53 +890,6 @@ struct server_context_impl { } } - // populate server properties - { - task_params params; - params.sampling = params_base.sampling; - json default_generation_settings_for_props = json { - {"params", params.to_json(true)}, - {"n_ctx", get_slot_n_ctx()}, - }; - - json_server_props = { - { "default_generation_settings", default_generation_settings_for_props }, - { "total_slots", params_base.n_parallel }, - { "model_alias", model_name }, - { "model_path", params_base.model.path }, - { "modalities", json { - {"vision", oai_parser_opt.allow_image}, - {"audio", oai_parser_opt.allow_audio}, - } }, - { "endpoint_slots", params_base.endpoint_slots }, - { "endpoint_props", params_base.endpoint_props 
}, - { "endpoint_metrics", params_base.endpoint_metrics }, - { "webui", params_base.webui }, - { "webui_settings", json_webui_settings }, - { "chat_template", common_chat_templates_source(chat_templates.get()) }, - { "bos_token", common_token_to_piece(ctx, llama_vocab_bos(vocab), /* special= */ true)}, - { "eos_token", common_token_to_piece(ctx, llama_vocab_eos(vocab), /* special= */ true)}, - { "build_info", build_info }, - }; - if (params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(chat_templates.get(), "tool_use")) { - json_server_props["chat_template_tool_use"] = tool_use_src; - } - } - } - - // populate model metadata - { - json_server_model_meta = { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; - } - return true; } @@ -1098,18 +1048,37 @@ struct server_context_impl { return res; } + std::vector construct_lora_list(const std::map & config) { + std::vector output = params_base.lora_adapters; // copy + for (size_t i = 0; i < output.size(); ++i) { + auto it = config.find(i); + if (it != config.end()) { + output[i].scale = it->second; + } else { + output[i].scale = 0.0f; + } + } + return output; + } + bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); - if (!are_lora_equal(task.params.lora, slot.lora)) { - // if lora has changed, check to see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, task.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); - slot.prompt.tokens.clear(); - } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); + // process per-request lora adapters + if (!task.params.lora.empty()) { + auto task_loras = construct_lora_list(task.params.lora); + if (!are_lora_equal(task_loras, slot.lora)) { + // if lora has changed, check to see if the cache should be cleared + if (lora_should_clear_cache(slot.lora, task_loras)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); + slot.prompt.tokens.clear(); + } else { + SLT_INF(slot, "keeping cache for alora. 
%zu target loras\n", task_loras.size()); + } + slot.lora = task_loras; } - slot.lora = task.params.lora; + } else { + slot.lora = params_base.lora_adapters; } // if using alora, make sure it's only a single one requested and active @@ -1858,9 +1827,41 @@ struct server_context_impl { res->n_erased = n_erased; queue_results.send(std::move(res)); } break; + case SERVER_TASK_TYPE_GET_LORA: + { + // TODO @ngxson : make lora_adapters a dedicated member of server_context + auto & loras = params_base.lora_adapters; + auto res = std::make_unique(); + res->id = task.id; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + std::string alora_invocation_string = ""; + const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); + llama_tokens alora_invocation_tokens; + if (n_alora_tokens) { + const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); + for (uint64_t j = 0; j < n_alora_tokens; ++j) { + alora_invocation_string += common_token_to_piece(vocab, alora_tokens[j]); + alora_invocation_tokens.push_back(alora_tokens[j]); + } + } + res->loras.push_back(server_task_result_get_lora::lora{ + lora, + alora_invocation_string, + alora_invocation_tokens, + }); + } + queue_results.send(std::move(res)); + } break; case SERVER_TASK_TYPE_SET_LORA: { - params_base.lora_adapters = std::move(task.set_lora); + auto new_loras = construct_lora_list(task.set_lora); + // logging + for (size_t i = 0; i < new_loras.size(); ++i) { + SRV_INF("set lora adapter idx=%zu scale=%f\n", i, new_loras[i].scale); + } + // TODO @ngxson : make lora_adapters a dedicated member of server_context + params_base.lora_adapters = new_loras; auto res = std::make_unique(); res->id = task.id; queue_results.send(std::move(res)); @@ -2331,6 +2332,12 @@ struct server_context_impl { slot.n_prompt_tokens_processed = 0; slot.prompt.tokens.keep_first(n_past); + + // send initial 0% progress update if needed + // this is to signal the client that the request has started processing + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } } if (!slot.can_split()) { @@ -2800,12 +2807,40 @@ server_response_reader server_context::get_response_reader() { return impl->get_response_reader(); } -server_context_info server_context::get_info() const { - return server_context_info { - /* build_info */ build_info, - /* model_name */ impl->model_name, - /* has_inp_image */ impl->oai_parser_opt.allow_image, - /* has_inp_audio */ impl->oai_parser_opt.allow_audio, +server_context_meta server_context::get_meta() const { + auto tool_use_src = common_chat_templates_source(impl->chat_templates.get(), "tool_use"); + + auto bos_id = llama_vocab_bos(impl->vocab); + auto eos_id = llama_vocab_eos(impl->vocab); + auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : ""; + auto eos_token_str = eos_id != LLAMA_TOKEN_NULL ? 
common_token_to_piece(impl->ctx, eos_id, true) : ""; + + return server_context_meta { + /* build_info */ build_info, + /* model_name */ impl->model_name, + /* model_path */ impl->params_base.model.path, + /* has_mtmd */ impl->mctx != nullptr, + /* has_inp_image */ impl->oai_parser_opt.allow_image, + /* has_inp_audio */ impl->oai_parser_opt.allow_audio, + /* json_webui_settings */ impl->json_webui_settings, + /* slot_n_ctx */ impl->get_slot_n_ctx(), + /* pooling_type */ llama_pooling_type(impl->ctx), + + /* chat_template */ common_chat_templates_source(impl->chat_templates.get()), + /* chat_template_tool_use */ tool_use_src ? tool_use_src : "", + + /* bos_token_str */ bos_token_str, + /* eos_token_str */ eos_token_str, + /* fim_pre_token */ llama_vocab_fim_pre(impl->vocab), + /* fim_sub_token */ llama_vocab_fim_suf(impl->vocab), + /* fim_mid_token */ llama_vocab_fim_mid(impl->vocab), + + /* model_vocab_type */ llama_vocab_type(impl->vocab), + /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab), + /* model_n_ctx_train */ llama_model_n_ctx_train(impl->model), + /* model_n_embd_inp */ llama_model_n_embd(impl->model), + /* model_n_params */ llama_model_n_params(impl->model), + /* model_size */ llama_model_size(impl->model), }; } @@ -2815,12 +2850,12 @@ server_context_info server_context::get_info() const { // may have bypass_sleep = true if the task does not use ctx_server struct server_res_generator : server_http_res { server_response_reader rd; - server_res_generator(server_context_impl & ctx_server, bool bypass_sleep = false) - : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) { + server_res_generator(server_queue & queue_tasks, server_response & queue_results, int sleep_idle_seconds, bool bypass_sleep = false) + : rd(queue_tasks, queue_results, HTTP_POLLING_SECONDS) { // fast path in case sleeping is disabled - bypass_sleep |= ctx_server.params_base.sleep_idle_seconds < 0; + bypass_sleep |= sleep_idle_seconds < 0; if (!bypass_sleep) { - ctx_server.queue_tasks.wait_until_no_sleep(); + queue_tasks.wait_until_no_sleep(); } } void ok(const json & response_data) { @@ -2839,17 +2874,15 @@ struct server_res_generator : server_http_res { // server_routes // -static std::unique_ptr handle_completions_impl( - std::unique_ptr && res_ptr, - server_context_impl & ctx_server, +std::unique_ptr server_routes::handle_completions_impl( + const server_http_req & req, server_task_type type, const json & data, const std::vector & files, - const std::function & should_stop, task_response_type res_type) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - auto res = std::move(res_ptr); + auto res = create_response(); auto completion_id = gen_chatcmplid(); auto & rd = res->rd; @@ -2871,32 +2904,30 @@ static std::unique_ptr handle_completions_impl( inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); } tasks.reserve(inputs.size()); - int idx = 0; for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = idx++; + task.id = rd.get_new_id(); task.tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, + ctx_server.vocab, + params, + meta->slot_n_ctx, data); task.id_slot = json_value(data, "id_slot", -1); // OAI-compat task.params.res_type = res_type; task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = ctx_server.model_name; + 
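+ // note: the model name is now read from the cached server_context_meta snapshot (meta->model_name)
+ // instead of ctx_server.model_name, so task construction does not need to touch the llama context;
+ // this assumes `meta` is refreshed via update_meta() whenever the model is (re)loaded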
task.params.oaicompat_model = meta->model_name; if (task.params.n_cmpl > 1) { task.n_children = task.params.n_cmpl - 1; for (size_t j = 0; j < task.n_children; j++) { server_task child = task.create_child( task.id, - ctx_server.queue_tasks.get_new_id(), - idx++); + rd.get_new_id()); tasks.push_back(std::move(child)); } } @@ -2914,7 +2945,7 @@ static std::unique_ptr handle_completions_impl( if (!stream) { // non-stream, wait for the results - auto all_results = rd.wait_for_all(should_stop); + auto all_results = rd.wait_for_all(req.should_stop); if (all_results.is_terminated) { return res; // connection is closed } else if (all_results.error) { @@ -2946,7 +2977,7 @@ static std::unique_ptr handle_completions_impl( // in streaming mode, the first error must be treated as non-stream response // this is to match the OAI API behavior // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd.next(should_stop); + server_task_result_ptr first_result = rd.next(req.should_stop); if (first_result == nullptr) { return res; // connection is closed } else if (first_result->is_error()) { @@ -2969,7 +3000,7 @@ static std::unique_ptr handle_completions_impl( } res->status = 200; res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { + res->next = [res_this = res.get(), res_type, &req](std::string & output) -> bool { static auto format_error = [](task_response_type res_type, const json & res_json) { if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { return format_anthropic_sse({ @@ -2982,7 +3013,7 @@ static std::unique_ptr handle_completions_impl( }; try { - if (should_stop()) { + if (req.should_stop()) { SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); return false; // should_stop condition met } @@ -3011,7 +3042,7 @@ static std::unique_ptr handle_completions_impl( } // receive subsequent results - auto result = rd.next(should_stop); + auto result = rd.next(req.should_stop); if (result == nullptr) { SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); return false; // should_stop condition met @@ -3052,37 +3083,51 @@ static std::unique_ptr handle_completions_impl( return res; } +std::unique_ptr server_routes::create_response(bool bypass_sleep) { + return std::make_unique(queue_tasks, queue_results, params.sleep_idle_seconds, bypass_sleep); +} + +server_routes::server_routes(const common_params & params, server_context & ctx_server) + : params(params), + ctx_server(*ctx_server.impl), + queue_tasks(ctx_server.impl->queue_tasks), + queue_results(ctx_server.impl->queue_results) { + init_routes(); +} + void server_routes::init_routes() { - // IMPORTANT: all lambda functions must start with std::make_unique + // IMPORTANT: all lambda functions must start with create_response() // this is to ensure that the server_res_generator can handle sleeping case correctly this->get_health = [this](const server_http_req &) { // error and loading states are handled by middleware - auto res = std::make_unique(ctx_server, true); + auto res = create_response(true); + + // this endpoint can be accessed during sleeping + // the next LOC is to avoid someone accidentally use ctx_server + bool server_ctx; // do NOT delete this line + GGML_UNUSED(server_ctx); + res->ok({{"status", "ok"}}); return res; }; - this->get_metrics = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); + this->get_metrics = [this](const server_http_req & req) { + 
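+ // note: create_response() (defined above) builds the server_res_generator; unless sleeping is
+ // disabled (params.sleep_idle_seconds < 0) or bypass_sleep is set, its constructor blocks on
+ // queue_tasks.wait_until_no_sleep(), so a handler only starts running once the server is awake again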
auto res = create_response(); if (!params.endpoint_metrics) { res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); return res; } // request slots data using task queue - // TODO: use server_response_reader - int task_id = ctx_server.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + task.id = res->rd.get_new_id(); + res->rd.post_task(std::move(task), true); // high-priority task } // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + auto result = res->rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3168,24 +3213,21 @@ void server_routes::init_routes() { }; this->get_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); if (!params.endpoint_slots) { res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); return res; } // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + task.id = res->rd.get_new_id(); + res->rd.post_task(std::move(task), true); // high-priority task } // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + auto result = res->rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3209,7 +3251,7 @@ void server_routes::init_routes() { }; this->post_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); if (params.slot_save_path.empty()) { res->error(format_error_response("This server does not support slots action. 
Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); return res; @@ -3240,15 +3282,51 @@ void server_routes::init_routes() { }; this->get_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server, true); - auto props = ctx_server.json_server_props; - props["is_sleeping"] = ctx_server.queue_tasks.is_sleeping(); + auto res = create_response(true); + + // this endpoint can be accessed during sleeping + // the next LOC is to avoid someone accidentally use ctx_server + bool server_ctx; // do NOT delete this line + GGML_UNUSED(server_ctx); + + task_params tparams; + tparams.sampling = params.sampling; + json default_generation_settings_for_props = json { + { "params", tparams.to_json(true) }, + { "n_ctx", meta->slot_n_ctx }, + }; + + json props = { + { "default_generation_settings", default_generation_settings_for_props }, + { "total_slots", params.n_parallel }, + { "model_alias", meta->model_name }, + { "model_path", meta->model_path }, + { "modalities", json { + {"vision", meta->has_inp_image}, + {"audio", meta->has_inp_audio}, + } }, + { "endpoint_slots", params.endpoint_slots }, + { "endpoint_props", params.endpoint_props }, + { "endpoint_metrics", params.endpoint_metrics }, + { "webui", params.webui }, + { "webui_settings", meta->json_webui_settings }, + { "chat_template", meta->chat_template }, + { "bos_token", meta->bos_token_str }, + { "eos_token", meta->eos_token_str }, + { "build_info", meta->build_info }, + { "is_sleeping", queue_tasks.is_sleeping() }, + }; + if (params.use_jinja) { + if (!meta->chat_template_tool_use.empty()) { + props["chat_template_tool_use"] = meta->chat_template_tool_use; + } + } res->ok(props); return res; }; this->post_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); if (!params.endpoint_props) { res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); return res; @@ -3260,20 +3338,16 @@ void server_routes::init_routes() { }; this->get_api_show = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool has_mtmd = ctx_server.mctx != nullptr; + auto res = create_response(); json data = { - { - "template", common_chat_templates_source(ctx_server.chat_templates.get()), - }, { "model_info", { - { "llama.context_length", ctx_server.get_slot_n_ctx() }, + { "llama.context_length", meta->slot_n_ctx }, } }, {"modelfile", ""}, {"parameters", ""}, - {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, + {"template", meta->chat_template}, {"details", { {"parent_model", ""}, {"format", "gguf"}, @@ -3283,7 +3357,7 @@ void server_routes::init_routes() { {"quantization_level", ""} }}, {"model_info", ""}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})} + {"capabilities", meta->has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})} }; res->ok(data); @@ -3291,7 +3365,7 @@ void server_routes::init_routes() { }; this->post_infill = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); // check model compatibility std::string err; if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { @@ -3352,54 +3426,48 @@ void server_routes::init_routes() { data.at("input_prefix"), data.at("input_suffix"), data.at("input_extra"), - ctx_server.params_base.n_batch, - ctx_server.params_base.n_predict, - ctx_server.get_slot_n_ctx(), - ctx_server.params_base.spm_infill, + params.n_batch, + params.n_predict, + meta->slot_n_ctx, + params.spm_infill, tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. ); std::vector files; // dummy return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_INFILL, data, files, - req.should_stop, TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible }; this->post_completions = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; // dummy const json body = json::parse(req.body); return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_COMPLETION, body, files, - req.should_stop, TASK_RESPONSE_TYPE_NONE); }; this->post_completions_oai = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; // dummy const json body = json::parse(req.body); return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_COMPLETION, body, files, - req.should_stop, TASK_RESPONSE_TYPE_OAI_CMPL); }; this->post_chat_completions = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; json body = json::parse(req.body); json body_parsed = oaicompat_chat_params_parse( @@ -3407,17 +3475,15 @@ void server_routes::init_routes() { ctx_server.oai_parser_opt, files); return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_COMPLETION, body_parsed, files, - req.should_stop, TASK_RESPONSE_TYPE_OAI_CHAT); }; this->post_anthropic_messages = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); json body_parsed = oaicompat_chat_params_parse( @@ -3425,17 +3491,15 @@ void server_routes::init_routes() { ctx_server.oai_parser_opt, files); return handle_completions_impl( - std::move(res), - ctx_server, + req, SERVER_TASK_TYPE_COMPLETION, body_parsed, files, - req.should_stop, TASK_RESPONSE_TYPE_ANTHROPIC); }; this->post_anthropic_count_tokens = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); json body_parsed = oaicompat_chat_params_parse( @@ -3445,14 +3509,13 @@ void server_routes::init_routes() { json prompt = body_parsed.at("prompt"); llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); - res->ok({{"input_tokens", static_cast(tokens.size())}}); return res; }; // same with handle_chat_completions, but without inference part this->post_apply_template = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = 
create_response(); std::vector files; // dummy, unused json body = json::parse(req.body); json data = oaicompat_chat_params_parse( @@ -3463,27 +3526,26 @@ void server_routes::init_routes() { return res; }; - // TODO: this endpoint is unsafe to access during model reloading (i.e. wake up from sleeping) - // how to make it work even during load_model()? this->get_models = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json model_meta = nullptr; - if (is_ready()) { - model_meta = ctx_server.json_server_model_meta; - } - bool has_mtmd = ctx_server.mctx != nullptr; + auto res = create_response(true); + + // this endpoint can be accessed during sleeping + // the next LOC is to avoid someone accidentally use ctx_server + bool server_ctx; // do NOT delete this line + GGML_UNUSED(server_ctx); + json models = { {"models", { { - {"name", ctx_server.model_name}, - {"model", ctx_server.model_name}, + {"name", meta->model_name}, + {"model", meta->model_name}, {"modified_at", ""}, {"size", ""}, {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash {"type", "model"}, {"description", ""}, {"tags", {""}}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, + {"capabilities", meta->has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, {"parameters", ""}, {"details", { {"parent_model", ""}, @@ -3498,11 +3560,18 @@ void server_routes::init_routes() { {"object", "list"}, {"data", { { - {"id", ctx_server.model_name}, + {"id", meta->model_name}, {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, - {"meta", model_meta}, + {"meta", { + {"vocab_type", meta->model_vocab_type}, + {"n_vocab", meta->model_vocab_n_tokens}, + {"n_ctx_train", meta->model_n_ctx_train}, + {"n_embd", meta->model_n_embd_inp}, + {"n_params", meta->model_n_params}, + {"size", meta->model_size}, + }}, }, }} }; @@ -3512,7 +3581,7 @@ void server_routes::init_routes() { }; this->post_tokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json body = json::parse(req.body); json tokens_response = json::array(); if (body.count("content") != 0) { @@ -3524,7 +3593,7 @@ void server_routes::init_routes() { if (with_pieces) { for (const auto& token : tokens) { - std::string piece = common_token_to_piece(ctx_server.ctx, token); + std::string piece = common_token_to_piece(ctx_server.vocab, token); json piece_json; // Check if the piece is valid UTF-8 @@ -3553,13 +3622,13 @@ void server_routes::init_routes() { }; this->post_detokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json body = json::parse(req.body); std::string content; if (body.count("tokens") != 0) { const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens); + content = tokens_to_str(ctx_server.vocab, tokens); } res->ok(json{{"content", std::move(content)}}); @@ -3575,8 +3644,8 @@ void server_routes::init_routes() { }; this->post_rerank = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { + auto res = create_response(); + if (!params.embedding || params.pooling_type != LLAMA_POOLING_TYPE_RANK) { res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return res; } @@ -3611,15 +3680,14 @@ void server_routes::init_routes() { // create and queue the task json responses = json::array(); - server_response_reader rd = ctx_server.get_response_reader(); + auto & rd = res->rd; { std::vector tasks; tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; + task.id = rd.get_new_id(); task.tokens = std::move(tmp); tasks.push_back(std::move(task)); } @@ -3645,7 +3713,7 @@ void server_routes::init_routes() { // write JSON response json root = format_response_rerank( body, - ctx_server.model_name, + meta->model_name, responses, is_tei_format, documents, @@ -3655,57 +3723,47 @@ void server_routes::init_routes() { return res; }; - this->get_lora_adapters = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json result = json::array(); - const auto & loras = ctx_server.params_base.lora_adapters; - for (size_t i = 0; i < loras.size(); ++i) { - auto & lora = loras[i]; - json entry = { - {"id", i}, - {"path", lora.path}, - {"scale", lora.scale}, - {"task_name", lora.task_name}, - {"prompt_prefix", lora.prompt_prefix}, - }; - std::string alora_invocation_string = ""; - const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); - std::vector alora_invocation_tokens; - if (n_alora_tokens) { - const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); - for (uint64_t i = 0; i < n_alora_tokens; ++i) { - alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); - alora_invocation_tokens.push_back(alora_tokens[i]); - } - entry["alora_invocation_string"] = alora_invocation_string; - entry["alora_invocation_tokens"] = alora_invocation_tokens; - } - result.push_back(std::move(entry)); + this->get_lora_adapters = [this](const server_http_req & req) { + auto res = create_response(); + + auto & rd = res->rd; + { + server_task task(SERVER_TASK_TYPE_GET_LORA); + task.id = rd.get_new_id(); + rd.post_task(std::move(task)); } - res->ok(result); + + // get the result + server_task_result_ptr result = rd.next(req.should_stop); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); return res; }; this->post_lora_adapters = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json body = json::parse(req.body); if (!body.is_array()) { res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); return res; } - int task_id = ctx_server.queue_tasks.get_new_id(); + auto & rd = res->rd; { server_task task(SERVER_TASK_TYPE_SET_LORA); - task.id = task_id; - task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + task.id = rd.get_new_id(); + task.set_lora = parse_lora_request(body); + rd.post_task(std::move(task)); } // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + server_task_result_ptr result = rd.next(req.should_stop); if 
(result->is_error()) { res->error(result->to_json()); @@ -3719,7 +3777,7 @@ void server_routes::init_routes() { } std::unique_ptr server_routes::handle_slots_save(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { @@ -3728,21 +3786,17 @@ std::unique_ptr server_routes::handle_slots_save(const ser } std::string filepath = params.slot_save_path + filename; - int task_id = ctx_server.queue_tasks.get_new_id(); + auto & rd = res->rd; { server_task task(SERVER_TASK_TYPE_SLOT_SAVE); - task.id = task_id; + task.id = rd.get_new_id(); task.slot_action.slot_id = id_slot; task.slot_action.filename = filename; task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + rd.post_task(std::move(task)); } - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + server_task_result_ptr result = rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3754,7 +3808,7 @@ std::unique_ptr server_routes::handle_slots_save(const ser } std::unique_ptr server_routes::handle_slots_restore(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); + auto res = create_response(); const json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { @@ -3763,21 +3817,17 @@ std::unique_ptr server_routes::handle_slots_restore(const } std::string filepath = params.slot_save_path + filename; - int task_id = ctx_server.queue_tasks.get_new_id(); + auto & rd = res->rd; { server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; + task.id = rd.get_new_id(); task.slot_action.slot_id = id_slot; task.slot_action.filename = filename; task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + rd.post_task(std::move(task)); } - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + server_task_result_ptr result = rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3789,21 +3839,17 @@ std::unique_ptr server_routes::handle_slots_restore(const return res; } -std::unique_ptr server_routes::handle_slots_erase(const server_http_req &, int id_slot) { - auto res = std::make_unique(ctx_server); - int task_id = ctx_server.queue_tasks.get_new_id(); +std::unique_ptr server_routes::handle_slots_erase(const server_http_req & req, int id_slot) { + auto res = create_response(); + auto & rd = res->rd; { server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; + task.id = rd.get_new_id(); task.slot_action.slot_id = id_slot; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + rd.post_task(std::move(task)); } - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); + server_task_result_ptr result = rd.next(req.should_stop); if (result->is_error()) { res->error(result->to_json()); @@ -3816,13 
 }
 
 std::unique_ptr<server_http_res> server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
-    auto res = std::make_unique<server_res_generator>(ctx_server);
-    if (!ctx_server.params_base.embedding) {
+    auto res = create_response();
+    if (!params.embedding) {
         res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
         return res;
     }
 
-    if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+    if (res_type != TASK_RESPONSE_TYPE_NONE && meta->pooling_type == LLAMA_POOLING_TYPE_NONE) {
         res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
         return res;
     }
@@ -3843,7 +3889,7 @@ std::unique_ptr<server_http_res> server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
 
     bool use_base64 = false;
     if (body.count("encoding_format") != 0) {
-        const std::string& format = body.at("encoding_format");
+        const std::string & format = body.at("encoding_format");
         if (format == "base64") {
             use_base64 = true;
         } else if (format != "float") {
@@ -3864,21 +3910,20 @@ std::unique_ptr<server_http_res> server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
     int embd_normalize = 2; // default to Euclidean/L2 norm
     if (body.count("embd_normalize") != 0) {
         embd_normalize = body.at("embd_normalize");
-        if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
-            SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx));
+        if (meta->pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", meta->pooling_type);
         }
     }
 
     // create and queue the task
     json responses = json::array();
-    server_response_reader rd = ctx_server.get_response_reader();
+    auto & rd = res->rd;
     {
         std::vector<server_task> tasks;
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
             server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
-            task.id = ctx_server.queue_tasks.get_new_id();
-            task.index = i;
+            task.id = rd.get_new_id();
             task.tokens = std::move(tokenized_prompts[i]);
 
             // OAI-compat
@@ -3908,7 +3953,7 @@ std::unique_ptr<server_http_res> server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
 
     // write JSON response
     json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
-        ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64)
+        ? format_embeddings_response_oaicompat(body, meta->model_name, responses, use_base64)
         : json(responses);
     res->ok(root);
     return res;
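
The rewritten handlers above all share one flow: allocate a response via `create_response()`, post tasks through the response-owned `server_response_reader`, then collect results by batch index. Below is a self-contained toy model of that flow; `task`, `result`, and `response_reader` here are simplified stand-ins, not the real llama.cpp types.

```cpp
// Toy model of the new handler flow (not the real llama.cpp API): each response
// owns one reader, tasks are posted exactly once, and results come back tagged
// with a batch index assigned by the reader itself.
#include <cstdio>
#include <string>
#include <vector>

struct task   { int id = -1; size_t index = 0; std::string payload; };
struct result { int id = -1; size_t index = 0; std::string value;   };

struct response_reader {
    int next_id = 0;
    std::vector<task> posted;

    int get_new_id() { return next_id++; }

    // post_tasks() assigns the batch index itself, so callers no longer set task.index
    void post_tasks(std::vector<task> && tasks) {
        for (size_t i = 0; i < tasks.size(); i++) tasks[i].index = i;
        posted = std::move(tasks);
    }

    // stand-in for wait_for_all(): here we just "process" synchronously
    std::vector<result> wait_for_all() {
        std::vector<result> results(posted.size());
        for (auto & t : posted) results[t.index] = { t.id, t.index, "processed:" + t.payload };
        return results;
    }
};

int main() {
    response_reader rd; // in the real server this lives inside the response object (res->rd)
    std::vector<task> tasks;
    for (const char * doc : { "a", "b", "c" }) {
        task t; t.id = rd.get_new_id(); t.payload = doc;
        tasks.push_back(std::move(t));
    }
    rd.post_tasks(std::move(tasks));
    for (const auto & r : rd.wait_for_all()) std::printf("#%zu -> %s\n", r.index, r.value.c_str());
}
```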
diff --git a/tools/server/server-context.h b/tools/server/server-context.h
index a56be7b8e7..09bec15ae1 100644
--- a/tools/server/server-context.h
+++ b/tools/server/server-context.h
@@ -9,11 +9,35 @@
 
 struct server_context_impl; // private implementation
 
-struct server_context_info {
+struct server_context_meta {
     std::string build_info;
     std::string model_name;
+    std::string model_path;
+
     bool has_mtmd;
     bool has_inp_image;
     bool has_inp_audio;
+
+    json json_webui_settings;
+    int slot_n_ctx;
+    enum llama_pooling_type pooling_type;
+
+    // chat template
+    std::string chat_template;
+    std::string chat_template_tool_use;
+
+    // tokens
+    std::string bos_token_str;
+    std::string eos_token_str;
+    llama_token fim_pre_token;
+    llama_token fim_sub_token;
+    llama_token fim_mid_token;
+
+    // model meta
+    enum llama_vocab_type model_vocab_type;
+    int32_t  model_vocab_n_tokens;
+    int32_t  model_n_ctx_train;
+    int32_t  model_n_embd_inp;
+    uint64_t model_n_params;
+    uint64_t model_size;
 };
 
 struct server_context {
@@ -33,14 +57,15 @@ struct server_context {
     void terminate();
 
     // get the underlaying llama_context, can return nullptr if sleeping
+    // not thread-safe, should only be used from the main thread
     llama_context * get_llama_context() const;
 
     // get a new response reader, used by CLI application
     server_response_reader get_response_reader();
 
-    // get server info
-    // used by CLI application
-    server_context_info get_info() const;
+    // get server metadata (read-only), can only be called after load_model()
+    // not thread-safe, should only be used from the main thread
+    server_context_meta get_meta() const;
 };
 
@@ -48,13 +73,17 @@ struct server_context {
 struct server_res_generator;
 
 struct server_routes {
-    server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
-        : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
-        init_routes();
-    }
+    server_routes(const common_params & params, server_context & ctx_server);
 
     void init_routes();
+
+    // note: this is not thread-safe and can only be called when ctx_http.is_ready is false
+    void update_meta(const server_context & ctx_server) {
+        this->meta = std::make_unique<server_context_meta>(ctx_server.get_meta());
+    }
+
     // handlers using lambda function, so that they can capture `this` without `std::bind`
+    // they won't be called until ctx_http.is_ready is set to true
     server_http_context::handler_t get_health;
     server_http_context::handler_t get_metrics;
     server_http_context::handler_t get_slots;
@@ -78,13 +107,24 @@ struct server_routes {
     server_http_context::handler_t get_lora_adapters;
     server_http_context::handler_t post_lora_adapters;
 private:
-    // TODO: move these outside of server_routes?
+    std::unique_ptr<server_http_res> handle_completions_impl(
+            const server_http_req & req,
+            server_task_type type,
+            const json & data,
+            const std::vector<raw_buffer> & files,
+            task_response_type res_type);
     std::unique_ptr<server_http_res> handle_slots_save(const server_http_req & req, int id_slot);
     std::unique_ptr<server_http_res> handle_slots_restore(const server_http_req & req, int id_slot);
     std::unique_ptr<server_http_res> handle_slots_erase(const server_http_req &, int id_slot);
     std::unique_ptr<server_http_res> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
 
+    // using unique_ptr to allow late initialization of const
+    std::unique_ptr<const server_context_meta> meta;
+
     const common_params & params;
-    server_context_impl & ctx_server;
-    std::function<bool()> is_ready;
+    const server_context_impl & ctx_server;
+
+    server_queue    & queue_tasks;
+    server_response & queue_results;
+
+    std::unique_ptr<server_res_generator> create_response(bool bypass_sleep = false);
 };
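
The `server_context_meta` change replaces ad-hoc reads of the live `llama_context` with a read-only snapshot taken once after the model loads. A minimal sketch of that pattern, with hypothetical `meta_snapshot` and `routes` names standing in for the real structs:

```cpp
// Minimal model of the snapshot pattern (names simplified): mutable state stays
// with the loader thread; HTTP handlers only ever see an immutable copy.
#include <cassert>
#include <memory>
#include <string>

struct meta_snapshot {
    std::string model_name;
    int n_ctx_train = 0;
};

struct routes {
    std::unique_ptr<const meta_snapshot> meta; // late-initialized const, as in the patch

    void update_meta(const meta_snapshot & m) {
        // in this sketch we require a single publication before the server is marked ready
        assert(!meta && "set once, before the server accepts requests");
        meta = std::make_unique<const meta_snapshot>(m);
    }
};

int main() {
    routes r;
    r.update_meta({ "my-model", 4096 });
    // handlers read meta->... without touching the live llama_context
    assert(r.meta->n_ctx_train == 4096);
}
```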
diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp
index 622505714c..5d67e5722d 100644
--- a/tools/server/server-http.cpp
+++ b/tools/server/server-http.cpp
@@ -177,12 +177,11 @@ bool server_http_context::init(const common_params & params) {
         if (!ready) {
             auto tmp = string_split(req.path, '.');
             if (req.path == "/" || tmp.back() == "html") {
-                res.set_content(reinterpret_cast<const char *>(loading_html), loading_html_len, "text/html; charset=utf-8");
                 res.status = 503;
-            } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
-                // allow the models endpoint to be accessed during loading
-                return true;
+                res.set_content(reinterpret_cast<const char *>(loading_html), loading_html_len, "text/html; charset=utf-8");
             } else {
+                // no endpoint is allowed to be accessed when the server is not ready
+                // this is to prevent any data races or inconsistent states
                 res.status = 503;
                 res.set_content(
                     safe_json_to_str(json {
@@ -334,12 +333,16 @@ static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
     return headers;
 }
 
-static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
+// using unique_ptr for request to allow safe capturing in lambdas
+using server_http_req_ptr = std::unique_ptr<server_http_req>;
+
+static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) {
     if (response->is_stream()) {
         res.status = response->status;
         set_headers(res, response->headers);
         std::string content_type = response->content_type;
         // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
+        std::shared_ptr<server_http_req> q_ptr = std::move(request);
         std::shared_ptr<server_http_res> r_ptr = std::move(response);
         const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
             std::string chunk;
@@ -355,8 +358,9 @@ static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
             }
             return has_next;
         };
-        const auto on_complete = [response = r_ptr](bool) mutable {
+        const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
             response.reset(); // trigger the destruction of the response object
+            request.reset();  // trigger the destruction of the request object
         };
         res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
     } else {
@@ -368,27 +372,29 @@
 
 void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
     pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
-        server_http_res_ptr response = handler(server_http_req{
+        server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
             get_params(req),
             get_headers(req),
             req.path,
             req.body,
             req.is_connection_closed
         });
-        process_handler_response(response, res);
+        server_http_res_ptr response = handler(*request);
+        process_handler_response(std::move(request), response, res);
     });
 }
 
 void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
     pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
-        server_http_res_ptr response = handler(server_http_req{
+        server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
             get_params(req),
             get_headers(req),
             req.path,
             req.body,
             req.is_connection_closed
         });
-        process_handler_response(response, res);
+        server_http_res_ptr response = handler(*request);
+        process_handler_response(std::move(request), response, res);
     });
 }
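
The `server-http.cpp` change fixes a lifetime issue: streaming handlers keep references into the request (e.g. `should_stop`), so the request object must outlive the HTTP handler frame. A reduced sketch of the ownership transfer, using plain `std::function` callbacks instead of httplib's provider API:

```cpp
// Sketch of the lifetime fix: the request must outlive the streaming callbacks,
// so ownership is moved into a shared_ptr captured by both lambdas.
#include <functional>
#include <iostream>
#include <memory>
#include <string>

struct http_req { std::string body; };

int main() {
    auto request = std::make_unique<http_req>(http_req{ "hello" });

    // convert to shared_ptr: both callbacks need it, and either may run last
    std::shared_ptr<http_req> q_ptr = std::move(request);

    std::function<void()> provider = [q_ptr]() {
        std::cout << "streaming for request: " << q_ptr->body << "\n";
    };
    std::function<void()> on_complete = [q_ptr]() mutable {
        q_ptr.reset(); // this copy releases its share; the last owner frees the request
    };

    provider();
    on_complete(); // provider's capture still keeps the request alive until it is destroyed
}
```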
diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp
index 835938bfc2..9a6ba560a3 100644
--- a/tools/server/server-queue.cpp
+++ b/tools/server/server-queue.cpp
@@ -325,23 +325,25 @@ void server_response::terminate() {
 // server_response_reader
 //
 
-void server_response_reader::post_task(server_task && task) {
+void server_response_reader::post_task(server_task && task, bool front) {
     GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
+    task.index = 0;
     id_tasks.insert(task.id);
     states.push_back(task.create_state());
     queue_results.add_waiting_task_id(task.id);
-    queue_tasks.post(std::move(task));
+    queue_tasks.post(std::move(task), front);
 }
 
-void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
+void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
     GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
     id_tasks = server_task::get_list_id(tasks);
     states.reserve(tasks.size());
     for (size_t i = 0; i < tasks.size(); i++) {
+        tasks[i].index = i;
         states.push_back(tasks[i].create_state());
     }
     queue_results.add_waiting_tasks(tasks);
-    queue_tasks.post(std::move(tasks));
+    queue_tasks.post(std::move(tasks), front);
 }
 
 bool server_response_reader::has_next() const {
@@ -367,7 +369,7 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
     }
     if (!states.empty()) {
         // update the generation state if needed
-        size_t idx = result->get_index();
+        const size_t idx = result->index;
         GGML_ASSERT(idx < states.size());
         result->update(states[idx]);
     }
@@ -383,6 +385,7 @@ server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
 
 server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
     batch_response batch_res;
+    batch_res.results.clear();
     batch_res.results.resize(id_tasks.size());
     while (has_next()) {
         auto res = next(should_stop);
@@ -394,7 +397,7 @@ server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
             batch_res.error = std::move(res);
             return batch_res;
         }
-        const size_t idx = res->get_index();
+        const size_t idx = res->index;
         GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
         GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
         batch_res.results[idx] = std::move(res);
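
The new `front` flag on `post_task()`/`post_tasks()` maps to posting at the head of the task queue. A toy illustration of the intended priority semantics, assuming the underlying queue is deque-like (the real `server_queue` internals are not shown in this diff):

```cpp
// Toy illustration of the `front` flag: high-priority tasks jump the queue.
#include <deque>
#include <iostream>

struct queue {
    std::deque<int> q;
    void post(int id, bool front = false) {
        if (front) q.push_front(id); else q.push_back(id);
    }
};

int main() {
    queue tasks;
    tasks.post(1);
    tasks.post(2);
    tasks.post(3, /*front=*/true); // e.g. a control task that must run next
    for (int id : tasks.q) std::cout << id << ' '; // prints: 3 1 2
    std::cout << '\n';
}
```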
diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h
index 8ac37a20f6..3798aa299e 100644
--- a/tools/server/server-queue.h
+++ b/tools/server/server-queue.h
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 
 // struct for managing server tasks
@@ -173,8 +174,10 @@ struct server_response_reader {
     int get_new_id() {
         return queue_tasks.get_new_id();
     }
-    void post_task(server_task && task);
-    void post_tasks(std::vector<server_task> && tasks);
+
+    // if front == true, the task will be posted to the front of the queue (high priority)
+    void post_task(server_task && task, bool front = false);
+    void post_tasks(std::vector<server_task> && tasks, bool front = false);
 
     bool has_next() const;
 
     // return nullptr if should_stop() is true before receiving a result
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 337895a5ef..82a84f9616 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -32,8 +32,8 @@ json task_params::to_json(bool only_metrics) const {
     }
 
     json lora = json::array();
-    for (size_t i = 0; i < this->lora.size(); ++i) {
-        lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+    for (auto & it : this->lora) {
+        lora.push_back({{"id", it.first}, {"scale", it.second}});
     }
 
     if (only_metrics) {
@@ -147,12 +147,10 @@ json task_params::to_json(bool only_metrics) const {
 
 // task_params
 
 task_params server_task::params_from_json_cmpl(
-        const llama_context * ctx,
+        const llama_vocab * vocab,
         const common_params & params_base,
+        const int n_ctx_slot,
         const json & data) {
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
     task_params params;
 
     // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
@@ -229,12 +227,12 @@ task_params server_task::params_from_json_cmpl(
 
     if (data.contains("lora")) {
         if (data.at("lora").is_array()) {
-            params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
+            params.lora = parse_lora_request(data.at("lora"));
         } else {
             throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
         }
     } else {
-        params.lora = params_base.lora_adapters;
+        params.lora = {};
     }
 
     // TODO: add more sanity checks for the input parameters
@@ -249,11 +247,11 @@ task_params server_task::params_from_json_cmpl(
 
     if (params.sampling.penalty_last_n == -1) {
         // note: should be the slot's context and not the full context, but it's ok
-        params.sampling.penalty_last_n = llama_n_ctx(ctx);
+        params.sampling.penalty_last_n = n_ctx_slot;
     }
 
     if (params.sampling.dry_penalty_last_n == -1) {
-        params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+        params.sampling.dry_penalty_last_n = n_ctx_slot;
     }
 
     if (params.sampling.dry_base < 1.0f) {
@@ -1159,7 +1157,7 @@ json server_task_result_rerank::to_json() {
 json server_task_result_cmpl_partial::to_json_anthropic() {
     json events = json::array();
     bool first = (n_decoded == 1);
-    static bool text_block_started = false;
+    bool text_block_started = false;
 
     if (first) {
         text_block_started = false;
@@ -1330,6 +1328,30 @@ json server_task_result_slot_erase::to_json() {
     };
 }
 
+//
+// server_task_result_get_lora
+//
+
+json server_task_result_get_lora::to_json() {
+    json result = json::array();
+    for (size_t i = 0; i < loras.size(); ++i) {
+        auto & lora = loras[i];
+        json entry = {
+            {"id", i},
+            {"path", lora.info.path},
+            {"scale", lora.info.scale},
+            {"task_name", lora.info.task_name},
+            {"prompt_prefix", lora.info.prompt_prefix},
+        };
+        if (!lora.alora_invocation_tokens.empty()) {
+            entry["alora_invocation_string"] = lora.alora_invocation_string;
+            entry["alora_invocation_tokens"] = lora.alora_invocation_tokens;
+        }
+        result.push_back(std::move(entry));
+    }
+    return result;
+}
+
 //
 // server_task_result_apply_lora
 //
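
With `task_params::lora` now a `std::map<int, float>`, a request carries only the adapters it overrides, keyed by adapter ID, and an empty map means "no per-request override" rather than a snapshot of the server-wide adapter list. A small sketch of what `to_json()` produces under that representation:

```cpp
// Sketch of the new LoRA representation: a sparse id -> scale map instead of a
// full copy of the adapter list.
#include <iostream>
#include <map>

int main() {
    std::map<int, float> lora = { { 0, 1.0f }, { 2, 0.5f } }; // adapters 0 and 2 active
    for (const auto & it : lora) {
        std::cout << "{\"id\":" << it.first << ",\"scale\":" << it.second << "}\n";
    }
    // adapters not present in the map simply keep their server-side defaults
}
```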
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index 0759094a01..687770de5e 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include <map>
 
 // TODO: prevent including the whole server-common.h as we only use server_tokens
 #include "server-common.h"
@@ -23,6 +24,7 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_GET_LORA,
     SERVER_TASK_TYPE_SET_LORA,
 };
 
@@ -60,7 +62,7 @@ struct task_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
-    std::vector<common_adapter_lora_info> lora;
+    std::map<int, float> lora; // mapping adapter ID -> scale
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
@@ -105,8 +107,10 @@ struct task_result_state {
 };
 
 struct server_task {
-    int id    = -1; // to be filled by server_queue
-    int index = -1; // used when there are multiple prompts (batch request)
+    int id = -1; // to be filled by server_queue
+
+    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
+    size_t index = 0; // used when there are multiple prompts (batch request)
 
     // used by SERVER_TASK_TYPE_CANCEL
     int id_target = -1;
@@ -138,7 +142,7 @@ struct server_task {
     bool metrics_reset_bucket = false;
 
     // used by SERVER_TASK_TYPE_SET_LORA
-    std::vector<common_adapter_lora_info> set_lora;
+    std::map<int, float> set_lora; // mapping adapter ID -> scale
 
     server_task() = default;
 
@@ -149,9 +153,10 @@ struct server_task {
     }
 
     static task_params params_from_json_cmpl(
-            const llama_context * ctx,
-            const common_params & params_base,
-            const json & data);
+            const llama_vocab * vocab,
+            const common_params & params_base,
+            const int n_ctx_slot,
+            const json & data);
 
     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -162,10 +167,9 @@ struct server_task {
         return ids;
     }
 
-    server_task create_child(int id_parent, int id_child, int idx) const {
+    server_task create_child(int id_parent, int id_child) const {
         server_task copy;
         copy.id = id_child;
-        copy.index = idx;
         copy.id_parent = id_parent;
         copy.params = params;
         copy.type = type;
@@ -212,6 +216,10 @@ struct result_prompt_progress {
 struct server_task_result {
     int id      = -1;
     int id_slot = -1;
+
+    // TODO @ngxson : remove this field and implement a mapping task_id -> idx in the response_reader
+    size_t index = 0; // to be used for batched tasks
+
     virtual bool is_error() {
         // only used by server_task_result_error
         return false;
@@ -220,9 +228,6 @@ struct server_task_result {
         // only used by server_task_result_cmpl_*
         return true;
     }
-    virtual int get_index() {
-        return -1;
-    }
     virtual void update(task_result_state &) {
         // only used by server_task_result_cmpl_*
     }
@@ -255,8 +260,6 @@ struct completion_token_output {
 };
 
 struct server_task_result_cmpl_final : server_task_result {
-    int index = 0;
-
     std::string content;
     llama_tokens tokens;
 
@@ -289,10 +292,6 @@ struct server_task_result_cmpl_final : server_task_result {
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
     bool is_updated = false;
 
-    virtual int get_index() override {
-        return index;
-    }
-
     virtual bool is_stop() override {
         return true; // in stream mode, final responses are considered stop
     }
@@ -318,8 +317,6 @@ struct server_task_result_cmpl_final : server_task_result {
 };
 
 struct server_task_result_cmpl_partial : server_task_result {
-    int index = 0;
-
     std::string content;
     llama_tokens tokens;
 
@@ -340,10 +337,6 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
     bool is_updated = false;
 
-    virtual int get_index() override {
-        return index;
-    }
-
     virtual bool is_stop() override {
         return false; // in stream mode, partial responses are not considered stop
     }
@@ -365,7 +358,6 @@ struct server_task_result_cmpl_partial : server_task_result {
 };
 
 struct server_task_result_embd : server_task_result {
-    int index = 0;
     std::vector<std::vector<float>> embedding;
     int32_t n_tokens;
 
@@ -373,10 +365,6 @@ struct server_task_result_embd : server_task_result {
     // response formatting
     task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
 
-    virtual int get_index() override {
-        return index;
-    }
-
     virtual json to_json() override;
 
     json to_json_non_oaicompat();
@@ -385,20 +373,14 @@ struct server_task_result_embd : server_task_result {
 };
 
 struct server_task_result_rerank : server_task_result {
-    int index = 0;
     float score = -1e6;
     int32_t n_tokens;
 
-    virtual int get_index() override {
-        return index;
-    }
-
     virtual json to_json() override;
 };
 
 struct server_task_result_error : server_task_result {
-    int index = 0;
     error_type err_type = ERROR_TYPE_SERVER;
     std::string err_msg;
 
@@ -460,6 +442,17 @@ struct server_task_result_slot_erase : server_task_result {
     virtual json to_json() override;
 };
 
+struct server_task_result_get_lora : server_task_result {
+    struct lora {
+        common_adapter_lora_info info;
+        std::string  alora_invocation_string;
+        llama_tokens alora_invocation_tokens;
+    };
+    std::vector<lora> loras;
+
+    virtual json to_json() override;
+};
+
 struct server_task_result_apply_lora : server_task_result {
     virtual json to_json() override;
 };
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index ff650ab2ec..0fbc7b6d35 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -119,7 +119,7 @@ int main(int argc, char ** argv, char ** envp) {
     //
 
     // register API routes
-    server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); });
+    server_routes routes(params, ctx_server);
 
     bool is_router_server = params.model.path.empty();
     std::optional models_routes{};
@@ -252,6 +252,7 @@ int main(int argc, char ** argv, char ** envp) {
         return 1;
     }
 
+    routes.update_meta(ctx_server);
     ctx_http.is_ready.store(true);
 
     LOG_INF("%s: model loaded\n", __func__);
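
`routes.update_meta(ctx_server)` runs strictly before `ctx_http.is_ready.store(true)`, and handlers are rejected with 503 until the flag flips, so no handler can observe a half-initialized `meta`. A compact model of that publication ordering (the real code uses default sequentially-consistent atomics; release/acquire, as sketched here, is already sufficient):

```cpp
// Minimal model of the startup ordering the patch relies on: metadata is
// published strictly before the ready flag flips, so any handler that is
// actually served observes fully-initialized metadata.
#include <atomic>
#include <cassert>
#include <string>

std::atomic<bool> is_ready{false};
std::string model_name; // stands in for routes.meta

void after_model_load() {
    model_name = "loaded-model";              // routes.update_meta(ctx_server);
    is_ready.store(true, std::memory_order_release); // ctx_http.is_ready.store(true);
}

bool handle_request() {
    if (!is_ready.load(std::memory_order_acquire)) return false; // the 503 path
    assert(!model_name.empty()); // safe: publication happened-before the flag was set
    return true;
}

int main() {
    assert(!handle_request()); // rejected while loading
    after_model_load();
    assert(handle_request());  // served with complete metadata
}
```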
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index 08b5265d48..73bd8add07 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -434,8 +434,8 @@ def test_context_size_exceeded_stream():
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
-        (64, 3, False),
-        (64, 1, True),
+        (64, 4, False),
+        (64, 2, True),
     ]
 )
 def test_return_progress(n_batch, batch_count, reuse_cache):
@@ -462,10 +462,18 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
     res = make_cmpl_request()
     last_progress = None
     total_batch_count = 0
+
     for data in res:
         cur_progress = data.get("prompt_progress", None)
         if cur_progress is None:
             continue
+        if total_batch_count == 0:
+            # first progress report must have n_cache == n_processed
+            assert cur_progress["total"] > 0
+            assert cur_progress["cache"] == cur_progress["processed"]
+            if reuse_cache:
+                # when reusing cache, we expect some cached tokens
+                assert cur_progress["cache"] > 0
         if last_progress is not None:
             assert cur_progress["total"] == last_progress["total"]
             assert cur_progress["cache"] == last_progress["cache"]
@@ -473,6 +481,7 @@ def test_return_progress(n_batch, batch_count, reuse_cache):
         total_batch_count += 1
         last_progress = cur_progress
 
+    # last progress should indicate completion (all tokens processed)
     assert last_progress is not None
     assert last_progress["total"] > 0
     assert last_progress["processed"] == last_progress["total"]
diff --git a/tools/server/webui/src/lib/stores/settings.svelte.ts b/tools/server/webui/src/lib/stores/settings.svelte.ts
index e163833bfb..cda940ba7e 100644
--- a/tools/server/webui/src/lib/stores/settings.svelte.ts
+++ b/tools/server/webui/src/lib/stores/settings.svelte.ts
@@ -294,15 +294,14 @@ class SettingsStore {
	 * This sets up the default values from /props endpoint
	 */
	syncWithServerDefaults(): void {
-		const serverParams = serverStore.defaultParams;
-		if (!serverParams) {
-			console.warn('No server parameters available for initialization');
+		const propsDefaults = this.getServerDefaults();
+
+		if (Object.keys(propsDefaults).length === 0) {
+			console.warn('No server defaults available for initialization');
 			return;
 		}
 
-		const propsDefaults = this.getServerDefaults();
-
 		for (const [key, propsValue] of Object.entries(propsDefaults)) {
 			const currentValue = getConfigValue(this.config, key);
 
diff --git a/tools/server/webui/src/routes/+layout.svelte b/tools/server/webui/src/routes/+layout.svelte
index a14dfb633c..095827b9ca 100644
--- a/tools/server/webui/src/routes/+layout.svelte
+++ b/tools/server/webui/src/routes/+layout.svelte
@@ -119,7 +119,7 @@
	$effect(() => {
		const serverProps = serverStore.props;
 
-		if (serverProps?.default_generation_settings?.params) {
+		if (serverProps) {
			settingsStore.syncWithServerDefaults();
		}
	});