diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml
index feb0d51205..c106f47a25 100644
--- a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml
+++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml
@@ -8,7 +8,8 @@ body:
value: >
Thanks for taking the time to fill out this bug report!
This issue template is intended for bug reports where the compilation of llama.cpp fails.
- Before opening an issue, please confirm that the compilation still fails with `-DGGML_CCACHE=OFF`.
+          Before opening an issue, please confirm that the compilation still fails
+          after recreating the CMake build directory and building with `-DGGML_CCACHE=OFF`.
If the compilation succeeds with ccache disabled you should be able to permanently fix the issue
by clearing `~/.cache/ccache` (on Linux).
- type: textarea
diff --git a/.github/ISSUE_TEMPLATE/011-bug-results.yml b/.github/ISSUE_TEMPLATE/011-bug-results.yml
index b815e70a8d..31202dfa83 100644
--- a/.github/ISSUE_TEMPLATE/011-bug-results.yml
+++ b/.github/ISSUE_TEMPLATE/011-bug-results.yml
@@ -98,7 +98,18 @@ body:
label: Relevant log output
description: >
Please copy and paste any relevant log output, including the command that you entered and any generated text.
- This will be automatically formatted into code, so no need for backticks.
- render: shell
+        For very long logs (thousands of lines), please upload them as files instead.
+ On Linux you can redirect console output into a file by appending ` > llama.log 2>&1` to your command.
+      value: |
+        <details>
+        <summary>Logs</summary>
+
+        ```console
+
+        ```
+
+        </details>
+
validations:
required: true
diff --git a/.github/ISSUE_TEMPLATE/019-bug-misc.yml b/.github/ISSUE_TEMPLATE/019-bug-misc.yml
index e1bd08ddd2..8e867e7f60 100644
--- a/.github/ISSUE_TEMPLATE/019-bug-misc.yml
+++ b/.github/ISSUE_TEMPLATE/019-bug-misc.yml
@@ -85,8 +85,19 @@ body:
label: Relevant log output
description: >
If applicable, please copy and paste any relevant log output, including any generated text.
- This will be automatically formatted into code, so no need for backticks.
If you are encountering problems specifically with the `llama_params_fit` module, always upload `--verbose` logs as well.
- render: shell
+ For very long logs (thousands of lines), please upload them as files instead.
+ On Linux you can redirect console output into a file by appending ` > llama.log 2>&1` to your command.
+      value: |
+        <details>
+        <summary>Logs</summary>
+
+        ```console
+
+        ```
+
+        </details>
+
validations:
required: false
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 7ca11b1dff..bfd1270716 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -45,8 +45,7 @@ jobs:
- { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
- { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" }
- { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04-s390x" }
- # Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete
- #- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: true }
+ - { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
steps:
- name: Check out the repo
uses: actions/checkout@v4
diff --git a/common/arg.cpp b/common/arg.cpp
index 87438d8d09..4b92d46f28 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2037,7 +2037,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
- "comma separated list of RPC servers",
+ "comma separated list of RPC servers (host:port)",
[](common_params & params, const std::string & value) {
add_rpc_devices(value);
GGML_UNUSED(params);
@@ -2157,11 +2157,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
+ GGML_ASSERT(params.n_gpu_layers < 0); // string_format would need to be extended for a default >= 0
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
- string_format("max. number of layers to store in VRAM (default: %d)", params.n_gpu_layers),
- [](common_params & params, int value) {
- params.n_gpu_layers = value;
+ string_format("max. number of layers to store in VRAM, either an exact number, 'auto', or 'all' (default: %s)", params.n_gpu_layers == -1 ? "auto" : "all"),
+ [](common_params & params, const std::string & value) {
+ if (value == "auto") {
+ params.n_gpu_layers = -1;
+ } else if (value == "all") {
+ params.n_gpu_layers = -2;
+ } else {
+ params.n_gpu_layers = std::stoi(value);
+ }
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: no usable GPU found, --gpu-layers option will be ignored\n");
fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n");
@@ -3195,11 +3202,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.devices = parse_device_list(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
+ GGML_ASSERT(params.speculative.n_gpu_layers < 0); // string_format would need to be extended for a default >= 0
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
- "number of layers to store in VRAM for the draft model",
- [](common_params & params, int value) {
- params.speculative.n_gpu_layers = value;
+ string_format("max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: %s)",
+ params.speculative.n_gpu_layers == -1 ? "auto" : "all"),
+ [](common_params & params, const std::string & value) {
+ if (value == "auto") {
+ params.speculative.n_gpu_layers = -1;
+ } else if (value == "all") {
+ params.speculative.n_gpu_layers = -2;
+ } else {
+ params.speculative.n_gpu_layers = std::stoi(value);
+ }
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: no usable GPU found, --gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n");
diff --git a/common/common.cpp b/common/common.cpp
index acf2ec841d..8d62893370 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1341,10 +1341,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.devices = params.devices.data();
}
- if (params.n_gpu_layers != -1) {
- mparams.n_gpu_layers = params.n_gpu_layers;
- }
-
+ mparams.n_gpu_layers = params.n_gpu_layers;
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
diff --git a/common/common.h b/common/common.h
index 2145f4f4c2..17023f0f89 100644
--- a/common/common.h
+++ b/common/common.h
@@ -332,7 +332,7 @@ struct common_params {
// offload params
    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
- int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+ int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
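For reference, a minimal sketch of how the new sentinel values are interpreted once loading starts; it mirrors the `llama_model::n_gpu_layers()` accessor added further down in this diff, and the helper name is purely illustrative:

```cpp
#include <cstdint>

// Sketch only (not part of this patch): resolving the n_gpu_layers sentinels.
//   -1   -> "auto": left to llama_params_fit(), which only modifies default values
//   <=-2 -> "all":  offload every layer plus the output layer
//   >=0  -> exact number of layers requested by the user
static uint32_t effective_gpu_layers(int32_t n_gpu_layers, uint32_t n_layer) {
    // mirrors llama_model::n_gpu_layers(): any negative value resolves to n_layer + 1
    return n_gpu_layers >= 0 ? (uint32_t) n_gpu_layers : n_layer + 1;
}
```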
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 69abb7367d..f893b24c75 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1696,6 +1696,84 @@ class TextModel(ModelBase):
if template is not None:
self.gguf_writer.add_chat_template(template)
+ def _set_vocab_plamo(self):
+ # PLaMo models use a custom tokenizer with a .jsonl file
+ tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
+ tokenizer_config_path = self.dir_model / "tokenizer_config.json"
+
+ if not tokenizer_jsonl_path.is_file():
+ raise FileNotFoundError(f"PLaMo tokenizer file not found: {tokenizer_jsonl_path}")
+
+ # Load tokenizer config
+ with open(tokenizer_config_path, "r", encoding="utf-8") as f:
+ tokenizer_config = json.load(f)
+
+ # Load tokens from JSONL file (actually a list format)
+ tokens = []
+ scores = []
+ toktypes = []
+
+ with open(tokenizer_jsonl_path, "r", encoding="utf-8") as f:
+ for line_num, line in enumerate(f):
+ if line.strip():
+ token_data = json.loads(line)
+ # Format: [token, score, type, ?, ?, ?, ?]
+ token = token_data[0].encode("utf-8")
+ score = float(token_data[1])
+ token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"
+
+ tokens.append(token)
+ scores.append(score)
+
+ if token_type_str == "UNKNOWN":
+ toktypes.append(gguf.TokenType.UNKNOWN)
+ elif token_type_str == "CONTROL":
+ toktypes.append(gguf.TokenType.CONTROL)
+ elif token_type_str == "BYTE":
+ toktypes.append(gguf.TokenType.BYTE)
+ else:
+ token_str = token_data[0]
+ if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
+ toktypes.append(gguf.TokenType.CONTROL)
+ else:
+ toktypes.append(gguf.TokenType.NORMAL)
+
+ vocab_size = self.hparams["vocab_size"]
+ if vocab_size > len(tokens):
+ pad_count = vocab_size - len(tokens)
+ logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+ for i in range(1, pad_count + 1):
+ tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+ scores.append(-1000.0)
+ toktypes.append(gguf.TokenType.UNUSED)
+
+ self.gguf_writer.add_tokenizer_model("plamo2")
+ self.gguf_writer.add_tokenizer_pre("default")
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_scores(scores)
+ self.gguf_writer.add_token_types(toktypes)
+
+ if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
+ token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8"))
+ self.gguf_writer.add_bos_token_id(token_id)
+ if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
+ token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8"))
+ self.gguf_writer.add_eos_token_id(token_id)
+ if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
+ token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8"))
+ self.gguf_writer.add_pad_token_id(token_id)
+ if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
+ token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8"))
+ self.gguf_writer.add_sep_token_id(token_id)
+ if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
+ token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8"))
+ self.gguf_writer.add_unk_token_id(token_id)
+
+ # Add <|plamo:op|> as EOT to ensure appropriate end of generation
+ self.gguf_writer.add_eot_token_id(4)
+
+ self.gguf_writer.add_add_space_prefix(False)
+
class MmprojModel(ModelBase):
model_type = ModelType.MMPROJ
@@ -4798,87 +4876,7 @@ class Plamo2Model(TextModel):
model_arch = gguf.MODEL_ARCH.PLAMO2
def set_vocab(self):
- # PLaMo 2 uses a custom tokenizer with a .jsonl file
- # We need to handle this specially
- tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl"
- tokenizer_config_path = self.dir_model / "tokenizer_config.json"
-
- if not tokenizer_jsonl_path.is_file():
- raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}")
-
- # Load tokenizer config
- with open(tokenizer_config_path, 'r', encoding='utf-8') as f:
- tokenizer_config = json.load(f)
-
- # Load tokens from JSONL file (actually a list format)
- tokens = []
- scores = []
- toktypes = []
-
- with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f:
- for line_num, line in enumerate(f):
- if line.strip():
- token_data = json.loads(line)
- # Format: [token, score, type, ?, ?, ?, ?]
- token = token_data[0].encode("utf-8")
- score = float(token_data[1])
- token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL"
-
- tokens.append(token)
- scores.append(score)
-
- # Map token type strings to GGUF token types
- if token_type_str == "UNKNOWN":
- toktypes.append(gguf.TokenType.UNKNOWN)
- elif token_type_str == "CONTROL":
- toktypes.append(gguf.TokenType.CONTROL)
- elif token_type_str == "BYTE":
- toktypes.append(gguf.TokenType.BYTE)
- else:
- # Check for PLaMo-2 special tokens
- token_str = token_data[0]
- if token_str.startswith("<|plamo:") and token_str.endswith("|>"):
- toktypes.append(gguf.TokenType.CONTROL)
- else:
- toktypes.append(gguf.TokenType.NORMAL)
-
- vocab_size = self.hparams["vocab_size"]
- if vocab_size > len(tokens):
- pad_count = vocab_size - len(tokens)
- logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
- for i in range(1, pad_count + 1):
- tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
- scores.append(-1000.0)
- toktypes.append(gguf.TokenType.UNUSED)
-
- # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer
- self.gguf_writer.add_tokenizer_model("plamo2")
- self.gguf_writer.add_tokenizer_pre("default")
- self.gguf_writer.add_token_list(tokens)
- self.gguf_writer.add_token_scores(scores)
- self.gguf_writer.add_token_types(toktypes)
-
- # Add special tokens from config
- if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None:
- token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8"))
- self.gguf_writer.add_bos_token_id(token_id)
- if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None:
- token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8"))
- self.gguf_writer.add_eos_token_id(token_id)
- if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None:
- token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8"))
- self.gguf_writer.add_pad_token_id(token_id)
- if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None:
- token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8"))
- self.gguf_writer.add_sep_token_id(token_id)
- if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None:
- token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8"))
- self.gguf_writer.add_unk_token_id(token_id)
-
- # Add <|plamo:op|> as EOT to ensure appropriate end of generation
- self.gguf_writer.add_eot_token_id(4)
-
- self.gguf_writer.add_add_space_prefix(False)
+ self._set_vocab_plamo()
def set_gguf_parameters(self):
hparams = self.hparams
@@ -4966,6 +4964,56 @@ class Plamo2Model(TextModel):
return [(new_name, data_torch)]
+@ModelBase.register("Plamo3ForCausalLM", "PLaMo3ForCausalLM")
+class Plamo3Model(TextModel):
+ model_arch = gguf.MODEL_ARCH.PLAMO3
+
+ def set_vocab(self):
+ self._set_vocab_plamo()
+
+ tokenizer_config_path = self.dir_model / "tokenizer_config.json"
+ tokenizer_config = {}
+
+ if tokenizer_config_path.is_file():
+ with open(tokenizer_config_path, encoding="utf-8") as f:
+ tokenizer_config = json.load(f)
+
+ chat_template = tokenizer_config.get("chat_template")
+ chat_template_jinja = self.dir_model / "chat_template.jinja"
+
+ if chat_template_jinja.is_file():
+ with open(chat_template_jinja, encoding="utf-8") as f:
+ chat_template = f.read()
+
+ if chat_template:
+ self.gguf_writer.add_chat_template(chat_template)
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+ if (sliding_window := self.find_hparam(["window_size", "sliding_window"], optional=True)) is not None:
+ self.gguf_writer.add_sliding_window(sliding_window)
+ self.gguf_writer.add_sliding_window_pattern(self.hparams["sliding_window_pattern"])
+ self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("rope_local_theta")})["rope_theta"])
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+
+ if name.endswith(".pre_mixer_norm.weight"):
+ data_torch = data_torch + 1.0
+ elif name.endswith(".post_mixer_norm.weight"):
+ data_torch = data_torch + 1.0 / 5
+ elif name.endswith(".pre_mlp_norm.weight"):
+ data_torch = data_torch + 1.0
+ elif name.endswith(".post_mlp_norm.weight"):
+ data_torch = data_torch + 1.0 / (5**1.5)
+ elif name.endswith((".mixer.q_norm.weight", ".mixer.k_norm.weight")):
+ data_torch = data_torch + 1.0
+ elif name.endswith(".norm.weight"):
+ data_torch = data_torch + 1.0
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+
@ModelBase.register("CodeShellForCausalLM")
class CodeShellModel(TextModel):
model_arch = gguf.MODEL_ARCH.CODESHELL
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 18d117f7cc..cb46c32100 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -430,10 +430,22 @@ if (MSVC)
configure_msvc_target(ggml-cpu-x64)
configure_msvc_target(ggml-cpu-sse42)
configure_msvc_target(ggml-cpu-sandybridge)
+ # __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+ # skipping ggml-cpu-ivybridge
+ # skipping ggml-cpu-piledriver
configure_msvc_target(ggml-cpu-haswell)
configure_msvc_target(ggml-cpu-skylakex)
+ configure_msvc_target(ggml-cpu-cannonlake)
+ configure_msvc_target(ggml-cpu-cascadelake)
configure_msvc_target(ggml-cpu-icelake)
+ # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
+ # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
+ # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
+ # skipping ggml-cpu-cooperlake
+ # skipping ggml-cpu-zen4
configure_msvc_target(ggml-cpu-alderlake)
+ # MSVC doesn't support AMX
+ # skipping ggml-cpu-sapphirerapids
if (GGML_BUILD_EXAMPLES)
configure_msvc_target(common-ggml)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 262d78a4cf..25f25c4236 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -357,15 +357,29 @@ if (GGML_CPU_ALL_VARIANTS)
endif()
if (GGML_SYSTEM_ARCH STREQUAL "x86")
ggml_add_cpu_backend_variant(x64)
- ggml_add_cpu_backend_variant(sse42 SSE42)
- ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
- ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA)
- ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
- ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
- ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
+ ggml_add_cpu_backend_variant(sse42 SSE42)
+ ggml_add_cpu_backend_variant(sandybridge SSE42 AVX)
+ if (NOT MSVC)
+ # __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+ ggml_add_cpu_backend_variant(ivybridge SSE42 AVX F16C)
+ ggml_add_cpu_backend_variant(piledriver SSE42 AVX F16C FMA)
+ endif()
+ ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C FMA AVX2 BMI2)
+ ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C FMA AVX2 BMI2 AVX512)
+ ggml_add_cpu_backend_variant(cannonlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI)
+ ggml_add_cpu_backend_variant(cascadelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI)
+ ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI)
+ if (NOT MSVC)
+ # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
+ # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
+ # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
+ ggml_add_cpu_backend_variant(cooperlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16)
+ ggml_add_cpu_backend_variant(zen4 SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)
+ endif()
+ ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)
if (NOT MSVC)
# MSVC doesn't support AMX
- ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
+ ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
endif()
elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
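As the comments above note, MSVC never defines `__FMA__`/`__F16C__` even though `/arch:AVX2` implies both instruction sets. A minimal illustration of the kind of guard this forces on code that dispatches on those macros (illustrative only, not the actual ggml workaround):

```cpp
// Illustrative only: normalize feature macros under MSVC, where /arch:AVX2 implies
// FMA and F16C support but the corresponding macros are never defined by the compiler.
#if defined(_MSC_VER) && defined(__AVX2__)
    #ifndef __FMA__
    #define __FMA__ 1
    #endif
    #ifndef __F16C__
    #define __F16C__ 1
    #endif
#endif

#include <immintrin.h>

#if defined(__F16C__)
// convert one IEEE fp16 value to fp32 using the F16C unit
static inline float fp16_to_fp32(unsigned short h) {
    return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(h)));
}
#endif
```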
diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 7597377cc2..0e8dd0ae05 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -328,7 +328,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
-#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
+#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__)
#include <immintrin.h>
#endif
diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h
index 101a9c086b..a7a8272205 100644
--- a/ggml/src/ggml-cpu/simd-mappings.h
+++ b/ggml/src/ggml-cpu/simd-mappings.h
@@ -14,10 +14,6 @@
#include
#endif
-#if defined(__F16C__)
-#include <immintrin.h>
-#endif
-
#if defined(__riscv_v_intrinsic)
#include <riscv_vector.h>
#endif
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index c0f8bcaa37..3b438c30ce 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -61,7 +61,7 @@ if (CUDAToolkit_FOUND)
set(CMAKE_CUDA_ARCHITECTURES ${PROCESSED_ARCHITECTURES})
else()
foreach(ARCH ${CMAKE_CUDA_ARCHITECTURES})
- if(ARCH MATCHES "^12[0-9]$")
+ if(ARCH MATCHES "^12[0-9](-real|-virtual)?$")
message(FATAL_ERROR "Compute capability ${ARCH} used, use ${ARCH}a or ${ARCH}f for Blackwell specific optimizations")
endif()
endforeach()
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 55fa2e6a7c..40ffe92c57 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2211,7 +2211,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = ggml_cuda_info().devices[id].warp_size;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
@@ -2219,7 +2219,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
- use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
@@ -2287,7 +2287,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
return;
}
- if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) {
+ if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
return;
}
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 6156dcdae7..85692d4543 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -259,7 +259,7 @@ void ggml_cuda_op_mul_mat_q(
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size);
}
-bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {
#ifdef GGML_CUDA_FORCE_CUBLAS
return false;
#endif // GGML_CUDA_FORCE_CUBLAS
@@ -320,7 +320,10 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
if (GGML_CUDA_CC_IS_CDNA3(cc)) {
return true;
}
- if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+ if (n_experts > 64 || ne11 <= 128) {
+ return true;
+ }
+ if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
return true;
}
if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
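For context, a condensed sketch of the batch-size checks in this branch after the change (the surrounding compute-capability checks are omitted, and the trailing `return false` is assumed from the unchanged remainder of the function). The likely rationale for the new `n_experts > 64` condition is that with `mul_mat_id` the `ne11` rows are spread across the experts, so each expert only sees a small effective batch and MMQ remains the better choice:

```cpp
#include <cstdint>
#include "ggml.h"

// Sketch only: tail of the heuristic as changed above. n_experts is 0 for plain
// mul_mat and ne02 (the number of experts) for mul_mat_id, as at the call sites.
static bool should_use_mmq_tail(ggml_type type, int64_t ne11, int64_t n_experts) {
    if (n_experts > 64 || ne11 <= 128) {
        return true; // many experts or a small batch: quantized MMQ wins
    }
    if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
        return true;
    }
    if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
        return true;
    }
    return false; // otherwise let the caller pick a different path
}
```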
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 63451ffab7..a382e6a697 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -4082,4 +4082,4 @@ void ggml_cuda_op_mul_mat_q(
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
-bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index fe57d4c582..80e0fd2ff8 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -24,10 +24,6 @@
#include
#endif
-#if defined(__F16C__)
-#include <immintrin.h>
-#endif
-
#ifdef __cplusplus
extern "C" {
#endif
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 639715537b..353f6a4b46 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -263,6 +263,32 @@ static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *drive
return { type, major, minor, patch };
}
+// cl buffer wrapper
+struct ggml_cl_buffer {
+ cl_mem buffer;
+ size_t size;
+
+ ggml_cl_buffer()
+ : buffer(nullptr), size(0) {}
+
+ ~ggml_cl_buffer() {
+ if (buffer) {
+ CL_CHECK(clReleaseMemObject(buffer));
+ }
+ }
+
+ void allocate(cl_context context, size_t new_size) {
+ if (new_size > size) {
+ size = new_size;
+ if (buffer) {
+ CL_CHECK(clReleaseMemObject(buffer));
+ }
+ cl_int err;
+ CL_CHECK((buffer = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err), err));
+ }
+ }
+};
+
// Profiling
struct ProfilingInfo {
std::string op_name;
@@ -376,6 +402,11 @@ struct ggml_backend_opencl_context {
cl_context context;
cl_command_queue queue;
+ // prealloc buffers for transposing weights and activations
+ ggml_cl_buffer prealloc_quant_trans;
+ ggml_cl_buffer prealloc_scales_trans;
+ ggml_cl_buffer prealloc_act_trans;
+
cl_program program_add;
cl_program program_add_id;
cl_program program_clamp;
@@ -638,10 +669,6 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_transpose_16_buf;
cl_kernel kernel_transpose_16_4x1;
- cl_mem A_s_d_max; // max scale buffer size for transpose
- cl_mem A_q_d_max; // max weight buffer size for transpose
- cl_mem B_d_max; // max activation buffer size for transpose
-
// Gemm and Gemv related programs, kernels, etc
cl_program program_CL_gemm;
cl_program program_CL_gemv_general;
@@ -2600,9 +2627,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
required_B_d_bytes, max_B_d_bytes);
}
- CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err));
- CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err));
- CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
+ backend_ctx->prealloc_quant_trans.allocate(context, max_A_q_d_bytes);
+ backend_ctx->prealloc_scales_trans.allocate(context, max_A_s_d_bytes);
+ backend_ctx->prealloc_act_trans.allocate(context, max_B_d_bytes);
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
@@ -3607,32 +3634,35 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
// use sub_buffer of max buffer size instead
size_t q_size_bytes = K * M / 8 * sizeof(float);
+ backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes);
+
cl_buffer_region region;
region.origin = 0;
region.size = q_size_bytes;
cl_mem qT_d = clCreateSubBuffer(
- backend_ctx->A_q_d_max,
+ backend_ctx->prealloc_quant_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
            &region,
&err);
- // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
CL_CHECK(err);
bool K_tile_trans = true;
if ((K / 32) % 4 != 0){
K_tile_trans =false;
}
+
size_t d_size_bytes = M * (K / 32) * 2;
+ backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes);
+
region.origin = 0;
region.size = d_size_bytes;
cl_mem dT_d = clCreateSubBuffer(
- backend_ctx->A_s_d_max,
+ backend_ctx->prealloc_scales_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
            &region,
&err);
- // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
CL_CHECK(err);
// <----------------------------------------------------------------------------------> //
@@ -7395,8 +7425,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
region.origin = 0;
// Specify the size of the sub-buffer (divide by 2 for FP16)
region.size = K * (N + padding) * sizeof(float)/2;
+ backend_ctx->prealloc_act_trans.allocate(context, region.size);
+
B_d = clCreateSubBuffer(
- backend_ctx->B_d_max,
+ backend_ctx->prealloc_act_trans.buffer,
0,
CL_BUFFER_CREATE_TYPE_REGION,
            &region,
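A short usage sketch for the grow-only `ggml_cl_buffer` wrapper introduced above (the function and sizes are illustrative; a valid `cl_context` and the struct from this patch are assumed):

```cpp
#include <CL/cl.h>

// Illustrative only: allocate() reallocates the underlying cl_mem only when the
// requested size exceeds the current capacity, so per-op callers such as
// ggml_backend_opencl_buffer_set_tensor() and ggml_cl_mul_mat() can simply
// request whatever they currently need.
static void prealloc_example(cl_context context) {
    ggml_cl_buffer buf;
    buf.allocate(context, 1u << 20); // first call: creates a 1 MiB CL_MEM_READ_WRITE buffer
    buf.allocate(context, 1u << 18); // smaller request: no reallocation, capacity stays 1 MiB
    buf.allocate(context, 1u << 22); // larger request: old cl_mem released, 4 MiB buffer created
    // buf.buffer can now serve as the parent of clCreateSubBuffer() regions,
    // as done above for qT_d, dT_d and B_d
}   // ~ggml_cl_buffer releases the cl_mem
```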
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index e7890a5ee9..164b39d01e 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -524,6 +524,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
+ GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
return nullptr;
}
#ifdef _WIN32
@@ -2053,6 +2054,10 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
auto sock = get_socket(endpoint);
+ if (sock == nullptr) {
+ GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
+ return 0;
+ }
rpc_msg_device_count_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
RPC_STATUS_ASSERT(status);
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 27578daaf9..616b8add36 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -377,6 +377,7 @@ class MODEL_ARCH(IntEnum):
PHIMOE = auto()
PLAMO = auto()
PLAMO2 = auto()
+ PLAMO3 = auto()
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()
@@ -773,6 +774,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.PHIMOE: "phimoe",
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.PLAMO2: "plamo2",
+ MODEL_ARCH.PLAMO3: "plamo3",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
@@ -1763,6 +1765,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_B_NORM,
MODEL_TENSOR.SSM_C_NORM,
],
+ MODEL_ARCH.PLAMO3: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_QKV,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
MODEL_ARCH.GPT2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.POS_EMBD,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 1690d991f2..115df6c7c3 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -595,6 +595,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
"model.layers.layers.{bid}.mixer.q", # plamo2
+ "model.layers.layers.{bid}.mixer.q_norm", # plamo3
"layers.{bid}.self_attn.q_norm", # qwen3-embedding
"model.layers.{bid}.attention.query_layernorm", # apertus
),
@@ -610,6 +611,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
"model.layers.layers.{bid}.mixer.k", # plamo2
+ "model.layers.layers.{bid}.mixer.k_norm", # plamo3
"layers.{bid}.self_attn.k_norm", # qwen3-embedding
"model.layers.{bid}.attention.key_layernorm", # apertus
),
diff --git a/include/llama.h b/include/llama.h
index 5e8974c94f..1b3cc16f1a 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -286,7 +286,7 @@ extern "C" {
// NULL-terminated list of buffer types to use for tensors that match a pattern
const struct llama_model_tensor_buft_override * tensor_buft_overrides;
- int32_t n_gpu_layers; // number of layers to store in VRAM
+ int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
// the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE
@@ -467,10 +467,17 @@ extern "C" {
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
+ enum llama_params_fit_status {
+ LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
+ LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
+        LLAMA_PARAMS_FIT_STATUS_ERROR   = 2, // a hard error occurred, e.g. because no model could be found at the specified path
+ };
+
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
- // returns true if the parameters could be successfully modified to fit device memory
- // this function is NOT thread safe because it modifies the global llama logger state
- LLAMA_API bool llama_params_fit(
+    // - returns LLAMA_PARAMS_FIT_STATUS_SUCCESS if the parameters could be modified to fit device memory
+    // - this function is NOT thread safe because it modifies the global llama logger state
+    // - only parameters that have the same value as in llama_model_default_params are modified
+ LLAMA_API enum llama_params_fit_status llama_params_fit(
const char * path_model,
struct llama_model_params * mparams,
struct llama_context_params * cparams,
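A minimal usage sketch of the revised `llama_params_fit()` API (the margin, minimum context, and array capacity below are illustrative; `tools/fit-params/fit-params.cpp` later in this diff shows the real call site):

```cpp
#include <vector>
#include "llama.h"

// Sketch only: fit default-valued model/context params to free device memory and
// distinguish "does not fit" from hard errors via the new tri-state status.
static bool fit_params_example(const char * path_model,
                               llama_model_params & mparams, llama_context_params & cparams) {
    // these buffers must outlive model creation, since mparams may end up pointing into them
    static float tensor_split[128] = {0};
    static std::vector<llama_model_tensor_buft_override> overrides(512); // capacity illustrative

    const enum llama_params_fit_status status = llama_params_fit(
        path_model, &mparams, &cparams, tensor_split, overrides.data(),
        /*margin_s  =*/ 1024u*1024u*1024u, // free device memory to leave untouched (illustrative)
        /*n_ctx_min =*/ 4096,
        GGML_LOG_LEVEL_INFO);

    return status == LLAMA_PARAMS_FIT_STATUS_SUCCESS;
}
```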
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 1e155534bd..762ea65c71 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -107,6 +107,7 @@ add_library(llama
models/phi3.cpp
models/plamo.cpp
models/plamo2.cpp
+ models/plamo3.cpp
models/plm.cpp
models/qwen.cpp
models/qwen2.cpp
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 75013d8d33..94a6807eac 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PHIMOE, "phimoe" },
{ LLM_ARCH_PLAMO, "plamo" },
{ LLM_ARCH_PLAMO2, "plamo2" },
+ { LLM_ARCH_PLAMO3, "plamo3" },
{ LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" },
{ LLM_ARCH_INTERNLM2, "internlm2" },
@@ -1077,6 +1078,22 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_FFN_POST_NORM,
};
+ case LLM_ARCH_PLAMO3:
+ return {
+ LLM_TENSOR_TOKEN_EMBD,
+ LLM_TENSOR_OUTPUT_NORM,
+ LLM_TENSOR_OUTPUT,
+ LLM_TENSOR_ATTN_NORM,
+ LLM_TENSOR_ATTN_QKV,
+ LLM_TENSOR_ATTN_Q_NORM,
+ LLM_TENSOR_ATTN_K_NORM,
+ LLM_TENSOR_ATTN_OUT,
+ LLM_TENSOR_ATTN_POST_NORM,
+ LLM_TENSOR_FFN_NORM,
+ LLM_TENSOR_FFN_POST_NORM,
+ LLM_TENSOR_FFN_DOWN,
+ LLM_TENSOR_FFN_UP,
+ };
case LLM_ARCH_CODESHELL:
return {
LLM_TENSOR_TOKEN_EMBD,
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 27bdedc83c..714ead4025 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -46,6 +46,7 @@ enum llm_arch {
LLM_ARCH_PHIMOE,
LLM_ARCH_PLAMO,
LLM_ARCH_PLAMO2,
+ LLM_ARCH_PLAMO3,
LLM_ARCH_CODESHELL,
LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 015ebae71d..1c530fdc91 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -294,8 +294,8 @@ llama_context::llama_context(
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
bool pipeline_parallel =
model.n_devices() > 1 &&
- model.params.n_gpu_layers > (int) model.hparams.n_layer &&
- model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
+ model.n_gpu_layers() > model.hparams.n_layer &&
+ model.split_mode() == LLAMA_SPLIT_MODE_LAYER &&
cparams.offload_kqv &&
!model.has_tensor_overrides();
@@ -1570,7 +1570,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
// norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
// FIXME: fix in ggml_backend_sched
- const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer;
+ const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
if (ubatch.n_tokens < 32 || full_offload) {
if (il != -1 && strcmp(name, "norm") == 0) {
const auto & dev_layer = model.dev_layer(il);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 69075742c9..5e664c8c57 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1227,6 +1227,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
} break;
+ case LLM_ARCH_PLAMO3:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+ if (found_swa && hparams.n_swa > 0) {
+ uint32_t swa_period = 8;
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+ hparams.rope_freq_scale_train_swa = 1.0f;
+ ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
+ ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
+ hparams.set_swa_pattern(swa_period);
+ } else {
+ hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+ }
+
+ switch (hparams.n_layer) {
+ case 24: type = LLM_TYPE_2B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_GPT2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2378,11 +2398,11 @@ void llama_model::load_vocab(llama_model_loader & ml) {
bool llama_model::load_tensors(llama_model_loader & ml) {
const auto & split_mode = params.split_mode;
- const auto & n_gpu_layers = params.n_gpu_layers;
const auto & use_mlock = params.use_mlock;
const auto & tensor_split = params.tensor_split;
- const int n_layer = hparams.n_layer;
+ const int n_layer = hparams.n_layer;
+ const int n_gpu_layers = this->n_gpu_layers();
const bool use_mmap_buffer = true;
@@ -3828,6 +3848,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_PLAMO3:
+ {
+ const int64_t head_dim_q = hparams.n_embd_head_k;
+ const int64_t head_dim_v = hparams.n_embd_head_v;
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ const int64_t num_attention_heads = hparams.n_head(i);
+ const int64_t num_key_value_heads = hparams.n_head_kv(i);
+ const int64_t q_proj_dim = num_attention_heads * head_dim_q;
+ const int64_t k_proj_dim = num_key_value_heads * head_dim_q;
+ const int64_t v_proj_dim = num_key_value_heads * head_dim_v;
+ const int64_t n_ff_cur = hparams.n_ff(i);
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i),
+ {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0);
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0);
+
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0);
+ }
+ } break;
case LLM_ARCH_GPT2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -6884,6 +6942,14 @@ size_t llama_model::n_devices() const {
return devices.size();
}
+uint32_t llama_model::n_gpu_layers() const {
+ return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1;
+}
+
+llama_split_mode llama_model::split_mode() const {
+ return params.split_mode;
+}
+
std::map llama_model::memory_breakdown() const {
std::map ret;
for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) {
@@ -7465,6 +7531,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique(*this, params);
} break;
+ case LLM_ARCH_PLAMO3:
+ {
+ if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                    llm = std::make_unique<llm_build_plamo3<true>> (*this, params);
+ } else {
+                    llm = std::make_unique<llm_build_plamo3<false>>(*this, params);
+ }
+ } break;
case LLM_ARCH_GPT2:
{
llm = std::make_unique(*this, params);
@@ -7794,7 +7868,7 @@ llama_model_params llama_model_default_params() {
llama_model_params result = {
/*.devices =*/ nullptr,
/*.tensor_buft_overrides =*/ nullptr,
- /*.n_gpu_layers =*/ 999,
+ /*.n_gpu_layers =*/ -1,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
@@ -7969,6 +8043,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_PLAMO:
case LLM_ARCH_PLAMO2:
+ case LLM_ARCH_PLAMO3:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
diff --git a/src/llama-model.h b/src/llama-model.h
index 9c00eec75f..dbe5edc153 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -466,8 +466,6 @@ struct llama_model {
struct ggml_tensor * dense_2_out_layers = nullptr;
struct ggml_tensor * dense_3_out_layers = nullptr;
- llama_model_params params;
-
// gguf metadata
    std::unordered_map<std::string, std::string> gguf_kv;
@@ -498,6 +496,9 @@ struct llama_model {
size_t n_tensors() const;
size_t n_devices() const;
+ uint32_t n_gpu_layers() const;
+ llama_split_mode split_mode() const;
+
std::map memory_breakdown() const;
// total number of parameters in the model
@@ -526,6 +527,8 @@ struct llama_model {
ggml_cgraph * build_graph(const llm_graph_params & params) const;
private:
+ llama_model_params params;
+
struct impl;
    std::unique_ptr<impl> pimpl;
};
diff --git a/src/llama.cpp b/src/llama.cpp
index 1e18637e36..76b3acbadb 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -140,6 +140,10 @@ enum layer_fraction_t {
};
// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue
+class llama_params_fit_exception : public std::runtime_error {
+ using std::runtime_error::runtime_error;
+};
+
static void llama_params_fit_impl(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
@@ -181,12 +185,11 @@ static void llama_params_fit_impl(
}
}
- int64_t sum_total = 0;
+ int64_t sum_free = 0;
int64_t sum_projected_free = 0;
int64_t min_projected_free = INT64_MAX;
int64_t sum_projected_used = 0;
int64_t sum_projected_model = 0;
- int64_t sum_projected_ctx = 0;
if (nd > 1) {
LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -197,12 +200,11 @@ static void llama_params_fit_impl(
const int64_t projected_used = dmd.mb.total();
const int64_t projected_free = dmd.free - projected_used;
- sum_total += dmd.total;
+ sum_free += dmd.free;
sum_projected_used += projected_used;
sum_projected_free += projected_free;
min_projected_free = std::min(min_projected_free, projected_free);
sum_projected_model += dmd.mb.model;
- sum_projected_ctx += dmd.mb.context;
if (nd > 1) {
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
@@ -210,10 +212,9 @@ static void llama_params_fit_impl(
projected_free >= 0 ? "surplus" : "deficit");
}
}
- assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0);
- assert(sum_projected_used >= sum_projected_ctx);
+ assert(sum_free >= 0 && sum_projected_used >= 0);
LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
- __func__, sum_projected_used/MiB, sum_total/MiB);
+ __func__, sum_projected_used/MiB, sum_free/MiB);
if (min_projected_free >= margin) {
if (nd == 1) {
LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
@@ -236,9 +237,7 @@ static void llama_params_fit_impl(
__func__, margin/MiB, -global_surplus/MiB);
if (cparams->n_ctx == 0) {
if (hp_nct > n_ctx_min) {
- const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
-
- int64_t memory_reduction = -global_surplus;
+ int64_t sum_used_target = sum_free - nd*margin_s;
if (nd > 1) {
// for multiple devices we need to be more conservative in terms of how much context we think can fit:
// - for dense models only whole layers can be assigned to devices
@@ -246,24 +245,34 @@ static void llama_params_fit_impl(
// - on average we expect a waste of 0.5 layers/tensors per device
// - use slightly more than the expected average for nd devices to be safe
const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
- memory_reduction += (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
+ sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
}
- uint32_t ctx_reduction = std::min(uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
- cparams->n_ctx = hp_nct - ctx_reduction;
- cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
+ int64_t sum_projected_used_min_ctx = 0;
+ cparams->n_ctx = n_ctx_min;
+ const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
+ for (const auto & dmd : dmds_min_ctx) {
+ sum_projected_used_min_ctx += dmd.mb.total();
+ }
+ if (sum_used_target > sum_projected_used_min_ctx) {
+ // linear interpolation between minimum and maximum context size:
+ cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx)
+ / (sum_projected_used - sum_projected_used_min_ctx);
+ cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
- ctx_reduction = hp_nct - cparams->n_ctx;
- memory_reduction = ctx_reduction * bytes_per_ctx;
- global_surplus += memory_reduction;
- LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
- __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
- if (global_surplus >= 0) {
+ const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min);
+ const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx;
+ LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
+ __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
if (nd == 1) {
LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__);
return;
}
LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__);
+ } else {
+ const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx;
+ LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
+ __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
}
} else {
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
@@ -276,28 +285,28 @@ static void llama_params_fit_impl(
}
if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) {
- throw std::runtime_error("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
+ throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort");
}
if (nd > 1) {
if (!tensor_split) {
- throw std::runtime_error("did not provide a buffer to write the tensor_split to, abort");
+ throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort");
}
if (mparams->tensor_split) {
for (size_t id = 0; id < nd; id++) {
if (mparams->tensor_split[id] != 0.0f) {
- throw std::runtime_error("model_params::tensor_split already set by user, abort");
+ throw llama_params_fit_exception("model_params::tensor_split already set by user, abort");
}
}
}
if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) {
- throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
+ throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort");
}
}
if (!tensor_buft_overrides) {
- throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort");
+ throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort");
}
if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) {
- throw std::runtime_error("model_params::tensor_buft_overrides already set by user, abort");
+ throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort");
}
// step 3: iteratively fill the back to front with "dense" layers
@@ -380,8 +389,8 @@ static void llama_params_fit_impl(
tensor_buft_overrides[itbo].buft = nullptr;
itbo++;
mparams.tensor_buft_overrides = tensor_buft_overrides;
- throw std::runtime_error("llama_params_fit_n_tensor_buft_overrides() == "
- + std::to_string(ntbo) + " is insufficient for model\n");
+ throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == "
+ + std::to_string(ntbo) + " is insufficient for model");
}
tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE);
tensor_buft_overrides[itbo].buft = overflow_bufts[id];
@@ -503,6 +512,9 @@ static void llama_params_fit_impl(
if (mem_high[id] > targets[id]) {
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
+ if (hp_nex > 0 && size_t(id) == nd - 1) {
+ delta--;
+ }
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
while (delta > 1) {
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
@@ -638,7 +650,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__);
std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
- if (mem_test[id] < targets[id]) {
+ if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@@ -648,7 +660,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
- if (mem_test[id] < targets[id]) {
+ if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@@ -659,7 +671,7 @@ static void llama_params_fit_impl(
ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN;
LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__);
mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts);
- if (mem_test[id] < targets[id]) {
+ if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
id_dense_start = id_dense_start_test;
@@ -678,22 +690,25 @@ static void llama_params_fit_impl(
set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams);
}
-bool llama_params_fit(
+enum llama_params_fit_status llama_params_fit(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
const int64_t t0_us = llama_time_us();
- bool ok = true;
+ llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS;
try {
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
- } catch (const std::runtime_error & e) {
+ } catch (const llama_params_fit_exception & e) {
LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());
- ok = false;
+ status = LLAMA_PARAMS_FIT_STATUS_FAILURE;
+ } catch (const std::runtime_error & e) {
+ LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what());
+ status = LLAMA_PARAMS_FIT_STATUS_ERROR;
}
const int64_t t1_us = llama_time_us();
LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6);
- return ok;
+ return status;
}
struct llama_sampler_chain_params llama_sampler_chain_default_params() {
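Condensed, the new context-size fitting above measures projected memory use at the minimum and at the default context and interpolates linearly between them, instead of extrapolating from an assumed per-token cost; a sketch using the same quantities as in the hunk:

```cpp
#include <algorithm>
#include <cstdint>

// Sketch of the interpolation in llama_params_fit_impl(): used_min_ctx and used_max_ctx
// are the projected totals at n_ctx_min and at the full default context hp_nct,
// used_target is the memory budget (free memory minus the per-device margins).
// The result is rounded down to a multiple of 256 for the CUDA backend,
// but never below n_ctx_min.
static uint32_t interpolate_n_ctx(uint32_t n_ctx_min, uint32_t hp_nct,
                                  int64_t used_min_ctx, int64_t used_max_ctx, int64_t used_target) {
    if (used_target <= used_min_ctx) {
        return n_ctx_min; // even the minimum context is projected to overshoot the budget
    }
    int64_t n_ctx = n_ctx_min + int64_t(hp_nct - n_ctx_min) * (used_target - used_min_ctx)
        / (used_max_ctx - used_min_ctx);
    n_ctx -= n_ctx % 256;
    return (uint32_t) std::max<int64_t>(n_ctx, n_ctx_min);
}
```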
diff --git a/src/models/models.h b/src/models/models.h
index dd0e286eda..e2cd4e484f 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -406,6 +406,11 @@ struct llm_build_plamo : public llm_graph_context {
llm_build_plamo(const llama_model & model, const llm_graph_params & params);
};
+template <bool iswa>
+struct llm_build_plamo3 : public llm_graph_context {
+ llm_build_plamo3(const llama_model & model, const llm_graph_params & params);
+};
+
struct llm_build_plm : public llm_graph_context {
llm_build_plm(const llama_model & model, const llm_graph_params & params);
};
diff --git a/src/models/plamo3.cpp b/src/models/plamo3.cpp
new file mode 100644
index 0000000000..55c8064679
--- /dev/null
+++ b/src/models/plamo3.cpp
@@ -0,0 +1,128 @@
+#include "models.h"
+
+template <bool iswa>
+llm_build_plamo3<iswa>::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) :
+ llm_graph_context(params) {
+ const int64_t head_dim_q = hparams.n_embd_head_k;
+ const int64_t head_dim_v = hparams.n_embd_head_v;
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL = build_inp_embd(model.tok_embd);
+ ggml_tensor * inp_pos = build_inp_pos();
+
+    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+ inp_attn_type * inp_attn = nullptr;
+
+ if constexpr (iswa) {
+ inp_attn = build_attn_inp_kv_iswa();
+ } else {
+ inp_attn = build_attn_inp_kv();
+ }
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * residual = inpL;
+
+ float freq_base_l = 0.0f;
+ float freq_scale_l = 0.0f;
+ if constexpr (iswa) {
+ freq_base_l = model.get_rope_freq_base (cparams, il);
+ freq_scale_l = model.get_rope_freq_scale(cparams, il);
+ } else {
+ freq_base_l = freq_base;
+ freq_scale_l = freq_scale;
+ }
+
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
+        cb(qkv, "wqkv", il);
+
+ const int32_t n_head = hparams.n_head(il);
+ const int32_t n_head_kv = hparams.n_head_kv(il);
+
+ const int64_t q_offset = 0;
+ const int64_t k_offset = head_dim_q * n_head;
+ const int64_t v_offset = k_offset + head_dim_q * n_head_kv;
+
+ ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head, n_tokens,
+ head_dim_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv));
+ ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head_kv, n_tokens,
+ head_dim_q * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv));
+ ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, head_dim_v, n_head_kv, n_tokens,
+ head_dim_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv));
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "attn_q_norm", il);
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Kcur, "attn_k_norm", il);
+
+ Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+ Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ const float attn_scale = 1.0f / sqrtf(float(head_dim_q));
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il);
+ cb(cur, "attn_out", il);
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+ }
+
+ cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ cur = ggml_add(ctx0, cur, residual);
+ cb(cur, "attn_residual", il);
+
+ residual = cur;
+
+ cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, NULL, NULL,
+ NULL, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+ cb(cur, "ffn_out", il);
+
+ cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "ffn_post_norm", il);
+
+ cur = ggml_add(ctx0, cur, residual);
+ cb(cur, "ffn_residual", il);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+ res->t_embd = cur;
+
+ cur = build_lora_mm(model.output, cur);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+}
+
+// Explicit template instantiations
+template struct llm_build_plamo3<false>;
+template struct llm_build_plamo3<true>;
diff --git a/tools/fit-params/fit-params.cpp b/tools/fit-params/fit-params.cpp
index de47763d3e..c7e7748ca9 100644
--- a/tools/fit-params/fit-params.cpp
+++ b/tools/fit-params/fit-params.cpp
@@ -26,10 +26,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
auto mparams = common_model_params_to_llama(params);
auto cparams = common_context_params_to_llama(params);
- const bool success = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
+ const llama_params_fit_status status = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
- if (!success) {
+ if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) {
LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);
exit(1);
}
diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h
index 8d6d4ef67b..e08c33f353 100644
--- a/tools/mtmd/models/models.h
+++ b/tools/mtmd/models/models.h
@@ -2,6 +2,11 @@
#include "../clip-graph.h"
+/*
+ * IMPORTANT: The mtmd module does NOT accept pull requests that are fully or predominantly AI-generated.
+ * We encourage human contributors to ensure the quality and reliability of the codebase.
+ */
+
struct clip_graph_siglip : clip_graph {
clip_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index 9f7e861e92..44d05ceaee 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -27,6 +27,9 @@
* - Make sure the C API is aligned with the libllama C API (as in llama.h)
* - Do not include model name (e.g., qwen, gemma) in the API, use generic terms instead
* - Keep the API minimal, do not expose internal details unless necessary
+ *
+ * IMPORTANT: The mtmd module does NOT accept pull requests that are fully or predominantly AI-generated.
+ * We encourage human contributors to ensure the quality and reliability of the codebase.
*/
#ifdef LLAMA_SHARED