Merge branch 'ggml-org:master' into Kimi-Linear

This commit is contained in:
ymcki 2026-01-09 14:09:56 +08:00 committed by GitHub
commit 6977ddbe85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
77 changed files with 5765 additions and 4975 deletions

View File

@ -33,6 +33,7 @@ FROM ubuntu:$UBUNTU_VERSION AS base
RUN apt-get update \
&& apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \
libglvnd0 libgl1 libglx0 libegl1 libgles2 \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \

View File

@ -152,13 +152,13 @@ jobs:
DAWN_VERSION="v2.0.0"
DAWN_OWNER="reeselevine"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
curl -L -o artifact.zip \
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
mkdir dawn
unzip artifact.zip
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
- name: Build
id: cmake_build
@ -532,13 +532,13 @@ jobs:
DAWN_VERSION="v2.0.0"
DAWN_OWNER="reeselevine"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
curl -L -o artifact.zip \
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
mkdir dawn
unzip artifact.zip
tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
- name: Build
id: cmake_build
@ -1418,7 +1418,6 @@ jobs:
echo "FIXME: test on devices"
openEuler-latest-cmake-cann:
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
defaults:
run:
shell: bash -el {0}
@ -1705,6 +1704,34 @@ jobs:
run: |
GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-mac-webgpu:
runs-on: [self-hosted, macOS, ARM64]
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Dawn Dependency
id: dawn-depends
run: |
DAWN_VERSION="v2.0.0"
DAWN_OWNER="reeselevine"
DAWN_REPO="dawn"
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
curl -L -o artifact.zip \
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
mkdir dawn
unzip artifact.zip
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
- name: Test
id: ggml-ci
run: |
GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
ggml-ci-mac-vulkan:
runs-on: [self-hosted, macOS, ARM64]

1
.gitignore vendored
View File

@ -130,6 +130,7 @@ poetry.toml
# Local scripts
/run-vim.sh
/run-chat.sh
/run-spec.sh
/.ccache/
# IDE

View File

@ -482,21 +482,6 @@ To learn more about model quantization, [read this documentation](tools/quantize
</details>
## [`llama-run`](tools/run)
#### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3].
- <details>
<summary>Run a model with a specific prompt (by default it's pulled from Ollama registry)</summary>
```bash
llama-run granite-code
```
</details>
[^3]: [RamaLama](https://github.com/containers/ramalama)
## [`llama-simple`](examples/simple)
#### A minimal example for implementing apps with `llama.cpp`. Useful for developers.
@ -600,7 +585,6 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc
- [stb-image](https://github.com/nothings/stb) - Single-header image format decoder, used by multimodal subsystem - Public domain
- [nlohmann/json](https://github.com/nlohmann/json) - Single-header JSON library, used by various tools/examples - MIT License
- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License
- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain
- [subprocess.h](https://github.com/sheredom/subprocess.h) - Single-header process launching solution for C and C++ - Public domain

View File

@ -105,7 +105,20 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
fi
if [ ! -z ${GG_BUILD_WEBGPU} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1"
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1 -DGGML_METAL=OFF -DGGML_BLAS=OFF"
if [ ! -z "${GG_BUILD_WEBGPU_DAWN_PREFIX}" ]; then
if [ -z "${CMAKE_PREFIX_PATH}" ]; then
export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}"
else
export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}:${CMAKE_PREFIX_PATH}"
fi
fi
# For some systems, Dawn_DIR needs to be set explicitly, e.g., the lib64 path
if [ ! -z "${GG_BUILD_WEBGPU_DAWN_DIR}" ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DDawn_DIR=${GG_BUILD_WEBGPU_DAWN_DIR}"
fi
fi
if [ ! -z ${GG_BUILD_MUSA} ]; then

View File

@ -6,6 +6,7 @@
#include "log.h"
#include "sampling.h"
#include "download.h"
#include "preset.h"
// fix problem with std::min and std::max
#if defined(_WIN32)
@ -268,6 +269,46 @@ static void parse_tensor_buffer_overrides(const std::string & value, std::vector
}
}
static std::string clean_file_name(const std::string & fname) {
std::string clean_fname = fname;
string_replace_all(clean_fname, "\\", "_");
string_replace_all(clean_fname, "/", "_");
return clean_fname;
}
static bool common_params_handle_remote_preset(common_params & params, llama_example ex) {
GGML_ASSERT(!params.model.hf_repo.empty());
const bool offline = params.offline;
std::string model_endpoint = get_model_endpoint();
auto preset_url = model_endpoint + params.model.hf_repo + "/resolve/main/preset.ini";
// prepare local path for caching
auto preset_fname = clean_file_name(params.model.hf_repo + "_preset.ini");
auto preset_path = fs_get_cache_file(preset_fname);
const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline);
const bool has_preset = status >= 200 && status < 400;
// remote preset is optional, so we don't error out if not found
if (has_preset) {
LOG_INF("applying remote preset from %s\n", preset_url.c_str());
common_preset_context ctx(ex, /* only_remote_allowed */ true);
common_preset global; // unused for now
auto remote_presets = ctx.load_from_ini(preset_path, global);
if (remote_presets.find(COMMON_PRESET_DEFAULT_NAME) != remote_presets.end()) {
common_preset & preset = remote_presets.at(COMMON_PRESET_DEFAULT_NAME);
LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline
preset.apply_to_params(params);
} else {
throw std::runtime_error("Remote preset.ini does not contain [" + std::string(COMMON_PRESET_DEFAULT_NAME) + "] section");
}
} else {
LOG_INF("%s", "no remote preset found, skipping\n");
}
return has_preset;
}
struct handle_model_result {
bool found_mmproj = false;
common_params_model mmproj;
@ -309,9 +350,7 @@ static handle_model_result common_params_handle_model(
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = model.hf_repo + "_" + model.hf_file;
// to make sure we don't have any slashes in the filename
string_replace_all(filename, "/", "_");
std::string filename = clean_file_name(model.hf_repo + "_" + model.hf_file);
model.path = fs_get_cache_file(filename);
}
@ -425,61 +464,87 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
};
std::set<std::string> seen_args;
auto parse_cli_args = [&]() {
std::set<std::string> seen_args;
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
std::string arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
bool is_positive = tmp.second;
if (opt.has_value_from_env()) {
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
}
try {
if (opt.handler_void) {
opt.handler_void(params);
continue;
std::string arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (opt.handler_bool) {
opt.handler_bool(params, is_positive);
continue;
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
bool is_positive = tmp.second;
if (opt.has_value_from_env()) {
fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str());
}
try {
if (opt.handler_void) {
opt.handler_void(params);
continue;
}
if (opt.handler_bool) {
opt.handler_bool(params, is_positive);
continue;
}
// arg with single value
check_arg(i);
std::string val = argv[++i];
if (opt.handler_int) {
opt.handler_int(params, std::stoi(val));
continue;
}
if (opt.handler_string) {
opt.handler_string(params, val);
continue;
}
// arg with single value
check_arg(i);
std::string val = argv[++i];
if (opt.handler_int) {
opt.handler_int(params, std::stoi(val));
continue;
}
if (opt.handler_string) {
opt.handler_string(params, val);
continue;
}
// arg with 2 values
check_arg(i);
std::string val2 = argv[++i];
if (opt.handler_str_str) {
opt.handler_str_str(params, val, val2);
continue;
// arg with 2 values
check_arg(i);
std::string val2 = argv[++i];
if (opt.handler_str_str) {
opt.handler_str_str(params, val, val2);
continue;
}
} catch (std::exception & e) {
throw std::invalid_argument(string_format(
"error while handling argument \"%s\": %s\n\n"
"usage:\n%s\n\nto show complete usage, run with -h",
arg.c_str(), e.what(), opt.to_string().c_str()));
}
} catch (std::exception & e) {
throw std::invalid_argument(string_format(
"error while handling argument \"%s\": %s\n\n"
"usage:\n%s\n\nto show complete usage, run with -h",
arg.c_str(), e.what(), opt.to_string().c_str()));
}
};
// parse the first time to get -hf option (used for remote preset)
parse_cli_args();
// maybe handle remote preset
if (!params.model.hf_repo.empty()) {
std::string cli_hf_repo = params.model.hf_repo;
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
// special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value)
// this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs)
std::string preset_hf_repo = params.model.hf_repo;
bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo;
if (has_preset) {
// re-parse CLI args to override preset values
parse_cli_args();
}
// preserve hf_repo from preset if needed
if (preset_has_hf_repo) {
params.model.hf_repo = preset_hf_repo;
}
}
@ -679,7 +744,6 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
"llama-quantize",
"llama-qwen2vl-cli",
"llama-retrieval",
"llama-run",
"llama-save-load-state",
"llama-server",
"llama-simple",
@ -1445,7 +1509,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.warmup = value;
}
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_DEBUG}));
add_opt(common_arg(
{"--spm-infill"},
string_format(
@ -1761,7 +1825,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING"));
add_opt(common_arg(
{"--attention"}, "{causal,non-causal}",
"attention type for embeddings, use model default if unspecified",
@ -2089,11 +2153,22 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg(
{"--mmap"},
{"--no-mmap"},
string_format("whether to memory-map model (if disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"),
string_format("whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.use_mmap = value;
if (value) {
params.use_direct_io = false; // disable direct io when mmap is explicitly enabled
}
}
).set_env("LLAMA_ARG_MMAP"));
add_opt(common_arg(
{"-dio", "--direct-io"},
{"-ndio", "--no-direct-io"},
string_format("use DirectIO if available. Takes precedence over --mmap (default: %s)", params.use_direct_io ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.use_direct_io = value;
}
).set_env("LLAMA_ARG_DIO"));
add_opt(common_arg(
{"--numa"}, "TYPE",
"attempt optimizations that help on some NUMA systems\n"
@ -2245,7 +2320,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
std::vector<std::string> split_arg{ it, {} };
if (split_arg.size() >= llama_max_devices()) {
throw std::invalid_argument(
string_format("got %d input configs, but system only has %d devices", (int)split_arg.size(), (int)llama_max_devices())
string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices())
);
}
for (size_t i = 0; i < llama_max_devices(); ++i) {
@ -2285,10 +2360,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_FIT"));
add_opt(common_arg(
{ "-fitt", "--fit-target" }, "MiB",
string_format("target margin per device for --fit option, default: %zu", params.fit_params_target/(1024*1024)),
[](common_params & params, int value) {
params.fit_params_target = value * size_t(1024*1024);
{ "-fitt", "--fit-target" }, "MiB0,MiB1,MiB2,...",
string_format("target margin per device for --fit, comma-separated list of values, "
"single value is broadcast across all devices, default: %zu", params.fit_params_target[0]/(1024*1024)),
[](common_params & params, const std::string & value) {
std::string arg_next = value;
// split string by , and /
const std::regex regex{ R"([,/]+)" };
std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
std::vector<std::string> split_arg{ it, {} };
if (split_arg.size() >= llama_max_devices()) {
throw std::invalid_argument(
string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices())
);
}
if (split_arg.size() == 1) {
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
return;
}
for (size_t i = 0; i < split_arg.size(); i++) {
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
}
}
).set_env("LLAMA_ARG_FIT_TARGET"));
add_opt(common_arg(
@ -2609,7 +2702,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, int value) {
params.embd_normalize = value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG}));
add_opt(common_arg(
{"--embd-output-format"}, "FORMAT",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
@ -2687,7 +2780,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.embedding = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(common_arg(
{"--rerank", "--reranking"},
string_format("enable reranking endpoint on server (default: %s)", "disabled"),
@ -3378,6 +3471,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{"--save-logits"},
string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"),
[](common_params & params) {
params.save_logits = true;
}
).set_examples({LLAMA_EXAMPLE_DEBUG}));
add_opt(common_arg(
{"--logits-output-dir"}, "PATH",
string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()),
[](common_params & params, const std::string & value) {
params.logits_output_dir = value;
}
).set_examples({LLAMA_EXAMPLE_DEBUG}));
add_opt(common_arg(
{"--tensor-filter"}, "REGEX",
"filter tensor names for debug output (regex pattern, can be specified multiple times)",
[](common_params & params, const std::string & value) {
params.tensor_filter.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_DEBUG}));
// presets
add_opt(common_arg(

View File

@ -129,11 +129,3 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
struct common_remote_params {
std::vector<std::string> headers;
long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout
long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB
};
// get remote file content, returns <http_code, raw_response_body>
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);

View File

@ -1097,7 +1097,7 @@ common_init_result::common_init_result(common_params & params) :
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
}
@ -1366,6 +1366,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_direct_io = params.use_direct_io;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
mparams.use_extra_bufts = !params.no_extra_bufts;

View File

@ -80,6 +80,7 @@ int32_t cpu_get_num_math();
//
enum llama_example {
LLAMA_EXAMPLE_DEBUG,
LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_COMPLETION,
@ -331,12 +332,14 @@ struct common_params {
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
// margin per device in bytes for fitting parameters to free memory:
std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024*1024);
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
@ -372,6 +375,11 @@ struct common_params {
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
// llama-debug specific options
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
bool save_logits = false; // whether to save logits to files // NOLINT
std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex) // NOLINT
std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides;
@ -422,7 +430,8 @@ struct common_params {
bool kv_unified = false; // enable unified KV cache
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // use mmap for faster loads
bool use_mmap = true; // enable mmap to use filesystem cache
bool use_direct_io = true; // read from disk without buffering for faster model loading
bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation
bool display_prompt = true; // print prompt before generation

View File

@ -157,6 +157,10 @@ static std::string read_etag(const std::string & path) {
return none;
}
static bool is_http_status_ok(int status) {
return status >= 200 && status < 400;
}
#ifdef LLAMA_USE_CURL
//
@ -306,11 +310,14 @@ static bool common_download_head(CURL * curl,
}
// download one single file from remote URL to local path
static bool common_download_file_single_online(const std::string & url,
// returns status code or -1 on error
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token) {
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
for (int i = 0; i < max_attempts; ++i) {
std::string etag;
@ -330,6 +337,11 @@ static bool common_download_file_single_online(const std::string & url,
common_load_model_from_url_headers headers;
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
curl_slist_ptr http_headers;
for (const auto & h : custom_headers) {
std::string s = h.first + ": " + h.second;
http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str());
}
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
if (!was_perform_successful) {
head_request_ok = false;
@ -365,7 +377,7 @@ static bool common_download_file_single_online(const std::string & url,
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
return -1;
}
}
@ -374,14 +386,14 @@ static bool common_download_file_single_online(const std::string & url,
if (std::filesystem::exists(path_temporary)) {
if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return false;
return -1;
}
}
if (std::filesystem::exists(path)) {
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
return -1;
}
}
}
@ -408,23 +420,27 @@ static bool common_download_file_single_online(const std::string & url,
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
int status = static_cast<int>(http_code);
if (!is_http_status_ok(http_code)) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
return status; // TODO: maybe only return on certain codes
}
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
return -1;
}
return static_cast<int>(http_code);
} else {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
}
break;
return 304; // Not Modified - fake cached response
}
}
return true;
return -1; // max attempts reached
}
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params) {
@ -454,8 +470,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
}
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
for (const auto & header : params.headers) {
http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
std::string header_ = header.first + ": " + header.second;
http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str());
}
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
@ -617,9 +635,11 @@ static bool common_pull_file(httplib::Client & cli,
}
// download one single file from remote URL to local path
static bool common_download_file_single_online(const std::string & url,
// returns status code or -1 on error
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token) {
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
@ -629,6 +649,9 @@ static bool common_download_file_single_online(const std::string & url,
if (!bearer_token.empty()) {
default_headers.insert({"Authorization", "Bearer " + bearer_token});
}
for (const auto & h : custom_headers) {
default_headers.emplace(h.first, h.second);
}
cli.set_default_headers(default_headers);
const bool file_exists = std::filesystem::exists(path);
@ -647,8 +670,10 @@ static bool common_download_file_single_online(const std::string & url,
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
return true;
return 304; // 304 Not Modified - fake cached response
}
return head->status; // cannot use cached file, return raw status code
// TODO: maybe retry only on certain codes
}
std::string etag;
@ -680,12 +705,12 @@ static bool common_download_file_single_online(const std::string & url,
if (file_exists) {
if (!should_download_from_scratch) {
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
return true;
return 304; // 304 Not Modified - fake cached response
}
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
return -1;
}
}
@ -697,7 +722,7 @@ static bool common_download_file_single_online(const std::string & url,
existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return false;
return -1;
}
}
@ -718,15 +743,16 @@ static bool common_download_file_single_online(const std::string & url,
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
return -1;
}
if (!etag.empty()) {
write_etag(path, etag);
}
break;
return head->status; // TODO: use actual GET status?
}
return true;
return -1; // max attempts reached
}
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
@ -734,13 +760,9 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
auto [cli, parts] = common_http_client(url);
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
for (const auto & header : params.headers) {
size_t pos = header.find(':');
if (pos != std::string::npos) {
headers.emplace(header.substr(0, pos), header.substr(pos + 1));
} else {
headers.emplace(header, "");
}
headers.emplace(header.first, header.second);
}
if (params.timeout > 0) {
@ -769,32 +791,45 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline) {
int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
if (!offline) {
return common_download_file_single_online(url, path, bearer_token);
return common_download_file_single_online(url, path, bearer_token, headers);
}
if (!std::filesystem::exists(path)) {
LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
return false;
return -1;
}
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
return true;
return 304; // Not Modified - fake cached response
}
// download multiple files from remote URLs to local paths
// the input is a vector of pairs <url, path>
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
futures_download.reserve(urls.size());
for (auto const & item : urls) {
futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token, offline);
}, item));
futures_download.push_back(
std::async(
std::launch::async,
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
return is_http_status_ok(http_status);
},
item
)
);
}
// Wait for all downloads to complete
@ -807,17 +842,18 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
return true;
}
bool common_download_model(
const common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool common_download_model(const common_params_model & model,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
// Basic validation of the model.url
if (model.url.empty()) {
LOG_ERR("%s: invalid model url\n", __func__);
return false;
}
if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
if (!is_http_status_ok(http_status)) {
return false;
}
@ -876,13 +912,16 @@ bool common_download_model(
}
// Download in parallel
common_download_file_multiple(urls, bearer_token, offline);
common_download_file_multiple(urls, bearer_token, offline, headers);
}
return true;
}
common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) {
common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
const std::string & bearer_token,
bool offline,
const common_header_list & custom_headers) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
@ -893,10 +932,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
// headers
std::vector<std::string> headers;
headers.push_back("Accept: application/json");
common_header_list headers = custom_headers;
headers.push_back({"Accept", "application/json"});
if (!bearer_token.empty()) {
headers.push_back("Authorization: Bearer " + bearer_token);
headers.push_back({"Authorization", "Bearer " + bearer_token});
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
// User-Agent header is already set in common_remote_get_content, no need to set it here
@ -952,7 +991,7 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
}
// check response
@ -1031,9 +1070,10 @@ std::string common_docker_resolve_model(const std::string & docker) {
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
std::string manifest_url = url_prefix + "/manifests/" + tag;
common_remote_params manifest_params;
manifest_params.headers.push_back("Authorization: Bearer " + token);
manifest_params.headers.push_back(
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
manifest_params.headers.push_back({"Authorization", "Bearer " + token});
manifest_params.headers.push_back({"Accept",
"application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
});
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
if (manifest_res.first != 200) {
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
@ -1070,7 +1110,8 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string local_path = fs_get_cache_file(model_filename);
const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false)) {
const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
if (!is_http_status_ok(http_status)) {
throw std::runtime_error("Failed to download Docker Model");
}
@ -1084,11 +1125,11 @@ std::string common_docker_resolve_model(const std::string & docker) {
#else
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) {
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
bool common_download_model(const common_params_model &, const std::string &, bool) {
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
@ -1096,6 +1137,14 @@ std::string common_docker_resolve_model(const std::string &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
int common_download_file_single(const std::string &,
const std::string &,
const std::string &,
bool,
const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
#endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
std::vector<common_cached_model_info> common_list_cached_models() {

View File

@ -1,12 +1,21 @@
#pragma once
#include <string>
#include <vector>
struct common_params_model;
//
// download functionalities
//
using common_header = std::pair<std::string, std::string>;
using common_header_list = std::vector<common_header>;
struct common_remote_params {
common_header_list headers;
long timeout = 0; // in seconds, 0 means no timeout
long max_size = 0; // unlimited if 0
};
// get remote file content, returns <http_code, raw_response_body>
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
struct common_cached_model_info {
std::string manifest_path;
@ -41,17 +50,29 @@ struct common_hf_file_res {
common_hf_file_res common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & bearer_token,
bool offline);
bool offline,
const common_header_list & headers = {}
);
// returns true if download succeeded
bool common_download_model(
const common_params_model & model,
const std::string & bearer_token,
bool offline);
bool offline,
const common_header_list & headers = {}
);
// returns list of cached models
std::vector<common_cached_model_info> common_list_cached_models();
// download single file from url to local path
// returns status code or -1 on error
int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers = {});
// resolve and download model from Docker registry
// return local path to downloaded model file
std::string common_docker_resolve_model(const std::string & docker);

View File

@ -16,6 +16,46 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
// only allow a subset of args for remote presets for security reasons
// do not add more args unless absolutely necessary
// args that output to files are strictly prohibited
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
static const std::set<std::string> allowed_options = {
"model-url",
"hf-repo",
"hf-repo-draft",
"hf-repo-v", // vocoder
"hf-file-v", // vocoder
"mmproj-url",
"pooling",
"jinja",
"batch-size",
"ubatch-size",
"cache-reuse",
// note: sampling params are automatically allowed by default
// negated args will be added automatically
};
std::set<std::string> allowed_keys;
for (const auto & it : key_to_opt) {
const std::string & key = it.first;
const common_arg & opt = it.second;
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
allowed_keys.insert(key);
// also add variant keys (args without leading dashes and env vars)
for (const auto & arg : opt.get_args()) {
allowed_keys.insert(rm_leading_dashes(arg));
}
for (const auto & env : opt.get_env()) {
allowed_keys.insert(env);
}
}
}
return allowed_keys;
}
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
@ -121,6 +161,29 @@ void common_preset::merge(const common_preset & other) {
}
}
void common_preset::apply_to_params(common_params & params) const {
for (const auto & [opt, val] : options) {
// apply each option to params
if (opt.handler_string) {
opt.handler_string(params, val);
} else if (opt.handler_int) {
opt.handler_int(params, std::stoi(val));
} else if (opt.handler_bool) {
opt.handler_bool(params, common_arg_utils::is_truthy(val));
} else if (opt.handler_str_str) {
// not supported yet
throw std::runtime_error(string_format(
"%s: option with two values is not supported yet",
__func__
));
} else if (opt.handler_void) {
opt.handler_void(params);
} else {
GGML_ABORT("unknown handler type");
}
}
}
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
std::map<std::string, std::map<std::string, std::string>> parsed;
@ -230,10 +293,16 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
return value;
}
common_preset_context::common_preset_context(llama_example ex)
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
: ctx_params(common_params_parser_init(default_params, ex)) {
common_params_add_preset_options(ctx_params.options);
key_to_opt = get_map_key_opt(ctx_params);
// setup allowed keys if only_remote_allowed is true
if (only_remote_allowed) {
filter_allowed_keys = true;
allowed_keys = get_remote_preset_whitelist(key_to_opt);
}
}
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
@ -250,6 +319,12 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
LOG_DBG("loading preset: %s\n", preset.name.c_str());
for (const auto & [key, value] : section.second) {
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
throw std::runtime_error(string_format(
"option '%s' is not allowed in remote presets",
key.c_str()
));
}
if (key_to_opt.find(key) != key_to_opt.end()) {
const auto & opt = key_to_opt.at(key);
if (is_bool_arg(opt)) {

View File

@ -6,6 +6,7 @@
#include <string>
#include <vector>
#include <map>
#include <set>
//
// INI preset parser and writer
@ -40,6 +41,9 @@ struct common_preset {
// merge another preset into this one, overwriting existing options
void merge(const common_preset & other);
// apply preset options to common_params
void apply_to_params(common_params & params) const;
};
// interface for multiple presets in one file
@ -50,7 +54,12 @@ struct common_preset_context {
common_params default_params; // unused for now
common_params_context ctx_params;
std::map<std::string, common_arg> key_to_opt;
common_preset_context(llama_example ex);
bool filter_allowed_keys = false;
std::set<std::string> allowed_keys;
// if only_remote_allowed is true, only accept whitelisted keys
common_preset_context(llama_example ex, bool only_remote_allowed = false);
// load presets from INI file
common_presets load_from_ini(const std::string & path, common_preset & global) const;

View File

@ -775,8 +775,8 @@ class TextModel(ModelBase):
self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}
rope_theta = self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)
local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "swa_rope_theta", "rope_local_base_freq"], optional=True)
rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True)
local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True)
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
@ -11234,8 +11234,8 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--sentence-transformers-dense-modules", action="store_true",
help=("Whether to include sentence-transformers dense modules."
"It can be used for sentence-transformers models, like google/embeddinggemma-300m"
help=("Whether to include sentence-transformers dense modules. "
"It can be used for sentence-transformers models, like google/embeddinggemma-300m. "
"Default these modules are not included.")
)

60
docs/preset.md Normal file
View File

@ -0,0 +1,60 @@
# llama.cpp INI Presets
## Introduction
The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/llama.cpp/pull/17859), allows users to create reusable and shareable parameter configurations for llama.cpp.
### Using Presets with the Server
When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details.
### Using a Remote Preset
> [!NOTE]
>
> This feature is currently only supported via the `-hf` option.
For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model.
Example:
```ini
hf-repo-draft = username/my-draft-model-GGUF
temp = 0.5
top-k = 20
top-p = 0.95
```
For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options.
Example usage:
Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above:
```sh
llama-cli -hf username/my-model-with-preset
# This is equivalent to:
llama-cli -hf username/my-model-with-preset \
--hf-repo-draft username/my-draft-model-GGUF \
--temp 0.5 \
--top-k 20 \
--top-p 0.95
```
You can also override preset arguments by specifying them on the command line:
```sh
# Force temp = 0.1, overriding the preset value
llama-cli -hf username/my-model-with-preset --temp 0.1
```
If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s):
```ini
hf-repo = user/my-model-main
hf-repo-draft = user/my-model-draft
temp = 0.8
ctx-size = 1024
; (and other configurations)
```

View File

@ -15,6 +15,7 @@ llama_add_compile_flags()
if (EMSCRIPTEN)
else()
add_subdirectory(batched)
add_subdirectory(debug)
add_subdirectory(embedding)
add_subdirectory(eval-callback)
@ -34,7 +35,6 @@ else()
add_subdirectory(gen-docs)
add_subdirectory(training)
add_subdirectory(diffusion)
add_subdirectory(model-conversion)
if (NOT GGML_BACKEND_DL)
add_subdirectory(convert-llama2c-to-ggml)
# these examples use the backends directly and cannot be built with dynamic loading

View File

@ -1,5 +1,5 @@
set(TARGET llama-logits)
add_executable(${TARGET} logits.cpp)
set(TARGET llama-debug)
add_executable(${TARGET} debug.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

54
examples/debug/README.md Normal file
View File

@ -0,0 +1,54 @@
# llama.cpp/examples/debug
This is a utility intended to help debug a model by registering a callback that
logs GGML operations and tensor data. It can also store the generated logits or
embeddings as well as the prompt and token ids for comparision with the original
model.
### Usage
```shell
llama-debug \
--hf-repo ggml-org/models \
--hf-file phi-2/ggml-model-q4_0.gguf \
--model phi-2-q4_0.gguf \
--prompt hello \
--save-logits \
--verbose
```
The tensor data is logged as debug and required the --verbose flag. The reason
for this is that while useful for a model with many layers there can be a lot of
output. You can filter the tensor names using the `--tensor-filter` option.
A recommended approach is to first run without `--verbose` and see if the
generated logits/embeddings are close to the original model. If they are not,
then it might be required to inspect tensor by tensor and in that case it is
useful to enable the `--verbose` flag along with `--tensor-filter` to focus on
specific tensors.
### Options
This example supports all standard `llama.cpp` options and also accepts the
following options:
```console
$ llama-debug --help
...
----- example-specific params -----
--save-logits save final logits to files for verification (default: false)
--logits-output-dir PATH directory for saving logits output files (default: data)
--tensor-filter REGEX filter tensor names for debug output (regex pattern, can be specified multiple times)
```
### Output Files
When `--save-logits` is enabled, the following files are created in the output
directory:
* `llamacpp-<model>[-embeddings].bin` - Binary output (logits or embeddings)
* `llamacpp-<model>[-embeddings].txt` - Text output (logits or embeddings, one per line)
* `llamacpp-<model>[-embeddings]-prompt.txt` - Prompt text and token IDs
* `llamacpp-<model>[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison
These files can be compared against the original model's output to verify the
converted model.

421
examples/debug/debug.cpp Normal file
View File

@ -0,0 +1,421 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include "ggml.h"
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <string>
#include <vector>
#include <filesystem>
#include <fstream>
#include <regex>
static void print_usage(int, char ** argv) {
const std::string usage_template = R"(
example usage:
Print tensors:
{prog} -m model.gguf -p "Hello my name is" --verbose
The tensors to be printed can be filtered with --tensor-filter option.
Save logits/embeddings:
{prog} -m model.gguf -p "Hello my name is" --save-logits
Add --embedding to save embeddings)" "\n";
// Fix the source code indentation above that is introduced by the raw string literal.
std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n");
usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]);
LOG("%s\n", usage.c_str());
}
static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data);
struct callback_data {
std::vector<uint8_t> data;
std::vector<std::regex> tensor_filters;
callback_data() = default;
callback_data(common_params & params, const std::vector<std::string> & filter_patterns) {
for (const auto & pattern : filter_patterns) {
try {
std::string anchored_pattern = "^" + pattern;
tensor_filters.emplace_back(anchored_pattern, std::regex::optimize);
} catch (const std::regex_error & e) {
throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
}
}
params.cb_eval = ggml_debug;
params.cb_eval_user_data = this;
}
};
struct output_data {
float * data_ptr = nullptr;
int data_size = 0;
std::string type_suffix;
std::vector<float> storage;
std::string prompt;
std::vector<llama_token> tokens;
output_data(llama_context * ctx, const llama_model * model, const common_params & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
tokens = common_tokenize(ctx, params.prompt, add_bos);
prompt = params.prompt;
if (params.embedding) {
const int n_embd = llama_model_n_embd_out(model);
const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE;
const int n_embd_count = pooling_enabled ? 1 : tokens.size();
const int n_embeddings = n_embd * n_embd_count;
float * embeddings;
if (pooling_enabled) {
embeddings = llama_get_embeddings_seq(ctx, 0);
storage.resize(n_embeddings);
common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize);
embeddings = storage.data();
} else {
embeddings = llama_get_embeddings(ctx);
}
data_ptr = embeddings;
data_size = n_embeddings;
type_suffix = "-embeddings";
} else {
const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
const int n_logits = llama_vocab_n_tokens(vocab);
data_ptr = const_cast<float*>(logits);
data_size = n_logits;
type_suffix = "";
}
}
};
static std::string ggml_ne_string(const ggml_tensor * t) {
std::string str;
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
str += std::to_string(t->ne[i]);
if (i + 1 < GGML_MAX_DIMS) {
str += ", ";
}
}
return str;
}
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
union {
float f;
uint32_t i;
} u;
u.i = (uint32_t)h.bits << 16;
return u.f;
}
static float ggml_get_float_value(const uint8_t * data, ggml_type type,
const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
switch (type) {
case GGML_TYPE_F16:
return ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
case GGML_TYPE_F32:
return *(const float *) &data[i];
case GGML_TYPE_I64:
return (float) *(const int64_t *) &data[i];
case GGML_TYPE_I32:
return (float) *(const int32_t *) &data[i];
case GGML_TYPE_I16:
return (float) *(const int16_t *) &data[i];
case GGML_TYPE_I8:
return (float) *(const int8_t *) &data[i];
case GGML_TYPE_BF16:
return ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
default:
GGML_ABORT("fatal error");
}
}
static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
GGML_ASSERT(n > 0);
float sum = 0;
float sum_sq = 0.0;
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
sum += v;
sum_sq += v * v;
}
}
}
}
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
LOG_DBG(" [\n");
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
if (i2 == n && ne[2] > 2*n) {
LOG_DBG(" ..., \n");
i2 = ne[2] - n;
}
LOG_DBG(" [\n");
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
if (i1 == n && ne[1] > 2*n) {
LOG_DBG(" ..., \n");
i1 = ne[1] - n;
}
LOG_DBG(" [");
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
if (i0 == n && ne[0] > 2*n) {
LOG_DBG("..., ");
i0 = ne[0] - n;
}
const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
LOG_DBG("%12.4f", v);
if (i0 < ne[0] - 1) {
LOG_DBG(", ");
}
}
LOG_DBG("],\n");
}
LOG_DBG(" ],\n");
}
LOG_DBG(" ]\n");
LOG_DBG(" sum = %f\n", sum);
LOG_DBG(" sum_sq = %f\n", sum_sq);
}
if (std::isnan(sum)) {
LOG_ERR("encountered NaN - aborting\n");
exit(0);
}
}
/**
* GGML operations callback during the graph execution.
*
* @param t current tensor
* @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor
* if we return true, a follow-up call will be made with ask=false in which we can do the actual collection.
* see ggml_backend_sched_eval_callback
* @param user_data user data to pass at each call back
* @return true to receive data or continue the graph, false otherwise
*/
static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
auto * cb_data = (callback_data *) user_data;
const struct ggml_tensor * src0 = t->src[0];
const struct ggml_tensor * src1 = t->src[1];
if (ask) {
return true; // Always retrieve data
}
bool matches_filter = cb_data->tensor_filters.empty();
if (!matches_filter) {
for (const auto & filter : cb_data->tensor_filters) {
if (std::regex_search(t->name, filter)) {
matches_filter = true;
break;
}
}
}
char src1_str[128] = {0};
if (src1) {
snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
}
if (matches_filter) {
LOG_DBG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
t->name,
ggml_type_name(t->type),
ggml_op_desc(t),
src0->name,
ggml_ne_string(src0).c_str(),
src1 ? src1_str : "",
ggml_ne_string(t).c_str());
}
const bool is_host = ggml_backend_buffer_is_host(t->buffer);
if (!is_host) {
auto n_bytes = ggml_nbytes(t);
cb_data->data.resize(n_bytes);
ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes);
}
if (!ggml_is_quantized(t->type) && matches_filter) {
uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
}
return true;
}
static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) {
std::filesystem::create_directory(output_dir);
auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix);
// Save logits/embeddings to binary file.
{
std::filesystem::path filepath{base_path.string() + ".bin"};
std::ofstream file{filepath, std::ios::binary};
if (!file) {
throw std::runtime_error("failed to open binary output file: " + filepath.string());
}
file.write(reinterpret_cast<const char*>(output.data_ptr), output.data_size * sizeof(float));
LOG("Data saved to %s\n", filepath.c_str());
}
// Save logits/embeddings to text file.
{
std::filesystem::path filepath{base_path.string() + ".txt"};
std::ofstream file{filepath};
if (!file) {
throw std::runtime_error("failed to open text output file: " + filepath.string());
}
for (int i = 0; i < output.data_size; i++) {
file << i << ": " << output.data_ptr[i] << '\n';
}
LOG("Data saved to %s\n", filepath.c_str());
}
// Save prompt and tokens to text file.
{
std::filesystem::path filepath{base_path.string() + "-prompt.txt"};
std::ofstream file{filepath};
if (!file) {
throw std::runtime_error("failed to open prompt output file: " + filepath.string());
}
file << "prompt: " << output.prompt << '\n';
file << "n_tokens: " << output.tokens.size() << '\n';
file << "token ids: ";
for (size_t i = 0; i < output.tokens.size(); i++) {
file << output.tokens[i];
if (i + 1 < output.tokens.size()) {
file << ", ";
}
}
file << '\n';
LOG("Prompt saved to %s\n", filepath.c_str());
}
// Save token ids to binary file.
{
std::filesystem::path filepath{base_path.string() + "-tokens.bin"};
std::ofstream file{filepath, std::ios::binary};
if (!file) {
throw std::runtime_error("failed to open tokens binary file: " + filepath.string());
}
file.write(reinterpret_cast<const char*>(output.tokens.data()), output.tokens.size() * sizeof(llama_token));
LOG("Tokens saved to %s\n", filepath.c_str());
}
}
static void print_tokenized_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, const std::string & prompt) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false");
LOG("Input prompt: \"%s\"\n", prompt.c_str());
LOG("Token ids (%zu):\n", tokens.size());
for (auto id : tokens) {
std::string piece(128, '\0');
int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true);
if (n < 0) {
LOG_ERR("failed to convert token %d to piece\n", id);
continue;
}
piece.resize(n);
LOG("%s(%d) ", piece.c_str(), id);
}
LOG("\n");
}
static bool run(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (tokens.empty()) {
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
return false;
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
print_tokenized_prompt(ctx, tokens, params.prompt);
if (params.save_logits) {
output_data output {ctx, model, params};
std::filesystem::path model_path{params.model.path};
std::string model_name{model_path.stem().string()};
save_output_data(output, model_name, params.logits_output_dir);
}
return true;
}
int main(int argc, char ** argv) {
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) {
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
callback_data cb_data(params, params.tensor_filter);
auto llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__);
return 1;
}
{
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
LOG_INF("\n");
}
if (!run(ctx, params)) {
return 1;
}
LOG("\n");
llama_perf_context_print(ctx);
llama_backend_free();
return 0;
}

View File

@ -553,6 +553,7 @@ int main(int argc, char ** argv) {
model_params.n_gpu_layers = params.n_gpu_layers;
model_params.devices = params.devices.data();
model_params.use_mmap = params.use_mmap;
model_params.use_direct_io = params.use_direct_io;
model_params.use_mlock = params.use_mlock;
model_params.check_tensors = params.check_tensors;

View File

@ -61,7 +61,7 @@ causal-run-converted-model:
@CONVERTED_MODEL="$(CONVERTED_MODEL)" ./scripts/causal/run-converted-model.sh
causal-verify-logits: causal-run-original-model causal-run-converted-model
@./scripts/causal/compare-logits.py
@MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/compare-logits.py
@MODEL_PATH="$(MODEL_PATH)" ./scripts/utils/check-nmse.py -m ${MODEL_PATH}
causal-run-original-embeddings:
@ -138,16 +138,13 @@ embedding-run-original-model-st: embedding-run-original-model
embedding-run-converted-model:
@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
$(if $(USE_POOLING),--pooling)
embedding-run-converted-model-st: USE_POOLING=1
embedding-run-converted-model-st: embedding-run-converted-model
$(if $(EMBD_NORMALIZE),--embd-normalize "$(EMBD_NORMALIZE)")
embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
@./scripts/embedding/compare-embeddings-logits.sh \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model
@./scripts/embedding/compare-embeddings-logits.sh \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")

View File

@ -198,14 +198,13 @@ model, and the other is a text file which allows for manual visual inspection.
#### Using SentenceTransformer with numbered layers
For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
03_Dense, 04_Normalize), use the `-st` targets to apply all these layers:
03_Dense, 04_Normalize), these will be applied automatically when running the
converted model but currently there is a separate target to run the original
version:
```console
# Run original model with SentenceTransformer (applies all numbered layers)
(venv) $ make embedding-run-original-model-st
# Run converted model with pooling enabled
(venv) $ make embedding-run-converted-model-st
```
This will use the SentenceTransformer library to load and run the model, which
@ -213,6 +212,17 @@ automatically applies all the numbered layers in the correct order. This is
particularly useful when comparing with models that should include these
additional transformation layers beyond just the base model output.
The type of normalization can be specified for the converted model but is not
strictly necessary as the verification uses cosine similarity and the magnitude
of the output vectors does not affect this. But the normalization type can be
specified as an argument to the target which might be useful for manual
inspection:
```console
(venv) $ make embedding-verify-logits-st EMBD_NORMALIZE=1
```
The original model will apply the normalization according to the normalization
layer specified in the modules.json configuration file.
### Model conversion
After updates have been made to [gguf-py](../../gguf-py) to add support for the
new model the model can be converted to GGUF format using the following command:

View File

@ -1,268 +0,0 @@
#include "llama.h"
#include "common.h"
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include <ctype.h>
#include <filesystem>
static void print_usage(int, char ** argv) {
printf("\nexample usage:\n");
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
printf("\n");
printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n");
printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
printf("\n");
}
int main(int argc, char ** argv) {
std::string model_path;
std::string prompt = "Hello, my name is";
int ngl = 0;
bool embedding_mode = false;
bool pooling_enabled = false;
int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
{
int i = 1;
for (; i < argc; i++) {
if (strcmp(argv[i], "-m") == 0) {
if (i + 1 < argc) {
model_path = argv[++i];
} else {
print_usage(argc, argv);
return 1;
}
} else if (strcmp(argv[i], "-ngl") == 0) {
if (i + 1 < argc) {
try {
ngl = std::stoi(argv[++i]);
} catch (...) {
print_usage(argc, argv);
return 1;
}
} else {
print_usage(argc, argv);
return 1;
}
} else if (strcmp(argv[i], "-embd-mode") == 0) {
embedding_mode = true;
} else if (strcmp(argv[i], "-pooling") == 0) {
pooling_enabled = true;
} else if (strcmp(argv[i], "-embd-norm") == 0) {
if (i + 1 < argc) {
try {
embd_norm = std::stoi(argv[++i]);
} catch (...) {
print_usage(argc, argv);
return 1;
}
} else {
print_usage(argc, argv);
return 1;
}
} else {
// prompt starts here
break;
}
}
if (model_path.empty()) {
print_usage(argc, argv);
return 1;
}
if (i < argc) {
prompt = argv[i++];
for (; i < argc; i++) {
prompt += " ";
prompt += argv[i];
}
}
}
ggml_backend_load_all();
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = ngl;
llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
// Extract basename from model_path
const char * basename = strrchr(model_path.c_str(), '/');
basename = (basename == NULL) ? model_path.c_str() : basename + 1;
char model_name[256];
strncpy(model_name, basename, 255);
model_name[255] = '\0';
char * dot = strrchr(model_name, '.');
if (dot != NULL && strcmp(dot, ".gguf") == 0) {
*dot = '\0';
}
printf("Model name: %s\n", model_name);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> prompt_tokens(n_prompt);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_prompt;
ctx_params.n_batch = n_prompt;
ctx_params.no_perf = false;
if (embedding_mode) {
ctx_params.embeddings = true;
ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
ctx_params.n_ubatch = ctx_params.n_batch;
}
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}
printf("Input prompt: \"%s\"\n", prompt.c_str());
printf("Tokenized prompt (%d tokens): ", n_prompt);
for (auto id : prompt_tokens) {
char buf[128];
int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return 1;
}
std::string s(buf, n);
printf("%s (%d)", s.c_str(), id);
}
printf("\n");
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
float * data_ptr;
int data_size;
const char * type;
std::vector<float> embd_out;
if (embedding_mode) {
const int n_embd_out = llama_model_n_embd_out(model);
const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
const int n_embeddings = n_embd_out * n_embd_count;
float * embeddings;
type = "-embeddings";
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
embeddings = llama_get_embeddings_seq(ctx, 0);
embd_out.resize(n_embeddings);
printf("Normalizing embeddings using norm: %d\n", embd_norm);
common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
embeddings = embd_out.data();
} else {
embeddings = llama_get_embeddings(ctx);
}
printf("Embedding dimension: %d\n", n_embd_out);
printf("\n");
// Print embeddings in the specified format
for (int j = 0; j < n_embd_count; j++) {
printf("embedding %d: ", j);
// Print first 3 values
for (int i = 0; i < 3 && i < n_embd_out; i++) {
printf("%9.6f ", embeddings[j * n_embd_out + i]);
}
printf(" ... ");
// Print last 3 values
for (int i = n_embd_out - 3; i < n_embd_out; i++) {
if (i >= 0) {
printf("%9.6f ", embeddings[j * n_embd_out + i]);
}
}
printf("\n");
}
printf("\n");
printf("Embeddings size: %d\n", n_embeddings);
data_ptr = embeddings;
data_size = n_embeddings;
} else {
float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
const int n_logits = llama_vocab_n_tokens(vocab);
type = "";
printf("Vocab size: %d\n", n_logits);
data_ptr = logits;
data_size = n_logits;
}
std::filesystem::create_directory("data");
// Save data to binary file
char bin_filename[512];
snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
printf("Saving data to %s\n", bin_filename);
FILE * f = fopen(bin_filename, "wb");
if (f == NULL) {
fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
return 1;
}
fwrite(data_ptr, sizeof(float), data_size, f);
fclose(f);
// Also save as text for debugging
char txt_filename[512];
snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type);
f = fopen(txt_filename, "w");
if (f == NULL) {
fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
return 1;
}
for (int i = 0; i < data_size; i++) {
fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
}
fclose(f);
if (!embedding_mode) {
printf("First 10 logits: ");
for (int i = 0; i < 10 && i < data_size; i++) {
printf("%.6f ", data_ptr[i]);
}
printf("\n");
printf("Last 10 logits: ");
for (int i = data_size - 10; i < data_size; i++) {
if (i >= 0) printf("%.6f ", data_ptr[i]);
}
printf("\n\n");
}
printf("Data saved to %s\n", bin_filename);
printf("Data saved to %s\n", txt_filename);
llama_free(ctx);
llama_model_free(model);
return 0;
}

View File

@ -3,10 +3,11 @@
import sys
import numpy as np
from pathlib import Path
import os
# Add utils directory to path for direct script execution
sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
from common import get_model_name_from_env_path # type: ignore[import-not-found]
from common import get_model_name_from_env_path, compare_tokens, exit_with_warning # type: ignore[import-not-found]
def quick_logits_check(pytorch_file, llamacpp_file):
"""Lightweight sanity check before NMSE"""
@ -38,6 +39,7 @@ def quick_logits_check(pytorch_file, llamacpp_file):
return True
def main():
model_path = os.environ.get('MODEL_PATH')
model_name = get_model_name_from_env_path('MODEL_PATH')
data_dir = Path("data")
pytorch_file = data_dir / f"pytorch-{model_name}.bin"
@ -58,6 +60,12 @@ def main():
print("Checked all required files were found. Proceeding...\n")
# Verify tokens as they are a prerequisite for logits comparison.
print("🔍 Token Comparison Check")
print("=" * 40)
if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"):
exit_with_warning("\n❌ Token mismatch detected", model_path)
print()
print("🔍 GGML Model Validation for model ", model_name)
print("=" * 40)
@ -73,8 +81,7 @@ def main():
print(" Ok to proceed with NMSE check...")
sys.exit(0)
else:
print(f"❌ NOK: Top 10 predictions don't match - generation will differ")
sys.exit(1)
exit_with_warning(f"❌ NOK: Top 10 predictions don't match - generation will differ", model_path)
if __name__ == "__main__":
main()

View File

@ -67,7 +67,7 @@ with torch.no_grad():
last_hidden_states = outputs.hidden_states[-1]
# Get embeddings for all tokens
token_embeddings = last_hidden_states[0].cpu().numpy() # Remove batch dimension
token_embeddings = last_hidden_states[0].float().cpu().numpy() # Remove batch dimension
print(f"Hidden states shape: {last_hidden_states.shape}")
print(f"Token embeddings shape: {token_embeddings.shape}")

View File

@ -13,6 +13,6 @@ if [ -z "$CONVERTED_MODEL" ]; then
exit 1
fi
cmake --build ../../build --target llama-logits -j8
cmake --build ../../build --target llama-debug -j8
../../build/bin/llama-logits -m $CONVERTED_MODEL -embd-mode "Hello world today"
../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits

View File

@ -21,6 +21,6 @@ fi
echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT
cmake --build ../../build --target llama-logits -j8
cmake --build ../../build --target llama-debug -j8
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits

View File

@ -7,12 +7,11 @@ import importlib
import torch
import numpy as np
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
# Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from utils.common import debug_hook
from utils.common import debug_hook, save_output_data
def parse_arguments():
parser = argparse.ArgumentParser(description="Process model with specified path")
@ -126,6 +125,7 @@ def main():
device = next(model.parameters()).device
prompt = get_prompt(args)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
token_ids = input_ids[0].cpu().tolist()
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
@ -151,19 +151,6 @@ def main():
print(f"Last token logits shape: {last_logits.shape}")
print(f"Vocab size: {len(last_logits)}")
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}.bin"
txt_filename = data_dir / f"pytorch-{model_name}.txt"
# Save to file for comparison
last_logits.astype(np.float32).tofile(bin_filename)
# Also save as text file for easy inspection
with open(txt_filename, "w") as f:
for i, logit in enumerate(last_logits):
f.write(f"{i}: {logit:.6f}\n")
# Print some sample logits for quick verification
print(f"First 10 logits: {last_logits[:10]}")
print(f"Last 10 logits: {last_logits[-10:]}")
@ -175,8 +162,7 @@ def main():
token = tokenizer.decode([idx])
print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logist to: {txt_filename}")
save_output_data(last_logits, token_ids, prompt, model_name)
if __name__ == "__main__":
main()

View File

@ -5,7 +5,7 @@ set -e
# Parse command line arguments
CONVERTED_MODEL=""
PROMPTS_FILE=""
USE_POOLING=""
EMBD_NORMALIZE="2"
while [[ $# -gt 0 ]]; do
case $1 in
@ -13,9 +13,9 @@ while [[ $# -gt 0 ]]; do
PROMPTS_FILE="$2"
shift 2
;;
--pooling)
USE_POOLING="1"
shift
--embd-normalize)
EMBD_NORMALIZE="$2"
shift 2
;;
*)
if [ -z "$CONVERTED_MODEL" ]; then
@ -50,10 +50,5 @@ fi
echo $CONVERTED_MODEL
cmake --build ../../build --target llama-logits -j8
# TODO: update logits.cpp to accept a --file/-f option for the prompt
if [ -n "$USE_POOLING" ]; then
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
else
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
fi
cmake --build ../../build --target llama-debug -j8
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE

View File

@ -3,13 +3,15 @@
import argparse
import os
import sys
import numpy as np
import importlib
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModel
import torch
# Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from utils.common import save_output_data
def parse_arguments():
parser = argparse.ArgumentParser(description='Run original embedding model')
@ -169,6 +171,7 @@ def main():
return_tensors="pt"
)
tokens = encoded['input_ids'][0]
token_ids = tokens.cpu().tolist()
token_strings = tokenizer.convert_ids_to_tokens(tokens)
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
print(f"{token_id:6d} -> '{token_str}'")
@ -185,6 +188,7 @@ def main():
)
tokens = encoded['input_ids'][0]
token_ids = tokens.cpu().tolist()
token_strings = tokenizer.convert_ids_to_tokens(tokens)
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
print(f"{token_id:6d} -> '{token_str}'")
@ -228,24 +232,11 @@ def main():
print()
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
flattened_embeddings = all_embeddings.flatten()
flattened_embeddings.astype(np.float32).tofile(bin_filename)
with open(txt_filename, "w") as f:
idx = 0
for j in range(n_embd_count):
for value in all_embeddings[j]:
f.write(f"{idx}: {value:.6f}\n")
idx += 1
print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
print("")
print(f"Saved bin embeddings to: {bin_filename}")
print(f"Saved txt embeddings to: {txt_filename}")
save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings")
if __name__ == "__main__":

View File

@ -3,6 +3,11 @@
import os
import sys
import torch
import transformers
import json
import textwrap
import numpy as np
from pathlib import Path
def get_model_name_from_env_path(env_path_name):
@ -148,3 +153,147 @@ def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_
# Patch it
setattr(module, function_name, debug_rope)
print(f"RoPE debug patching applied to {model_module_path}.{function_name}")
def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"):
"""
Save output data (logits/embeddings), tokens, and prompt to files.
Args:
data: numpy array of floats (logits or embeddings)
tokens: list or array of token IDs
prompt: string containing the input prompt
model_name: name of the model
type_suffix: optional suffix like "-embeddings" (default: "")
output_dir: directory to save files (default: "data")
Creates the following files in output_dir:
- pytorch-{model_name}{type_suffix}.bin
- pytorch-{model_name}{type_suffix}.txt
- pytorch-{model_name}{type_suffix}-prompt.txt
- pytorch-{model_name}{type_suffix}-tokens.bin
"""
data_dir = Path(output_dir)
data_dir.mkdir(exist_ok=True)
base_path = data_dir / f"pytorch-{model_name}{type_suffix}"
# Convert and flatten logits/embeddings
data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data)
data = data.flatten() if data.ndim > 1 else data
# Save logits/embedding files
data.astype(np.float32).tofile(f"{base_path}.bin")
print(f"Data saved to {base_path}.bin")
with open(f"{base_path}.txt", "w") as f:
f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data))
print(f"Data saved to {base_path}.txt")
# Convert and flatten tokens
tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens)
tokens = tokens.flatten() if tokens.ndim > 1 else tokens
# Save token binary file
tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin")
print(f"Tokens saved to {base_path}-tokens.bin")
# Save prompt file
with open(f"{base_path}-prompt.txt", "w") as f:
f.write(f"prompt: {prompt}\n")
f.write(f"n_tokens: {len(tokens)}\n")
f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n")
print(f"Prompt saved to {base_path}-prompt.txt")
def compare_tokens(original, converted, type_suffix="", output_dir="data"):
data_dir = Path(output_dir)
# Read tokens from both models
tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin"
tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin"
if not tokens1_file.exists():
print(f"Error: Token file not found: {tokens1_file}")
return False
if not tokens2_file.exists():
print(f"Error: Token file not found: {tokens2_file}")
return False
tokens1 = np.fromfile(tokens1_file, dtype=np.int32)
tokens2 = np.fromfile(tokens2_file, dtype=np.int32)
print(f"\nComparing tokens between:")
print(f" Original : {original} ({len(tokens1)} tokens)")
print(f" Converted: {converted} ({len(tokens2)} tokens)")
if len(tokens1) != len(tokens2):
print(f"\n❌ Token count mismatch: {len(tokens1)} vs {len(tokens2)}")
return False
if np.array_equal(tokens1, tokens2):
print(f"\n✅ All {len(tokens1)} tokens match!")
return True
mismatches = np.where(tokens1 != tokens2)[0]
print(f"\n❌ Found {len(mismatches)} mismatched tokens:")
num_to_show = min(len(mismatches), 10)
for idx in mismatches[:num_to_show]:
print(f" Position {idx}: {tokens1[idx]} vs {tokens2[idx]}")
if len(mismatches) > num_to_show:
print(f" ... and {len(mismatches) - num_to_show} more mismatches")
return False
def show_version_warning(current_version, model_version):
if not model_version:
return False
try:
from packaging.version import parse, InvalidVersion
try:
return parse(current_version) < parse(model_version)
except InvalidVersion:
return current_version != model_version
except ImportError:
return current_version != model_version
def get_model_transformers_version(model_path):
if not model_path:
return None
config_path = Path(model_path) / "config.json"
if not config_path.is_file():
return None
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
return config.get("transformers_version")
except (IOError, json.JSONDecodeError) as e:
print(f"Warning: Could not read or parse {config_path}: {e}", file=sys.stderr)
return None
def exit_with_warning(message, model_path):
print(message)
if model_path and transformers is not None:
model_transformers_version = get_model_transformers_version(model_path)
transformers_version = transformers.__version__
if show_version_warning(transformers_version, model_transformers_version):
warning_message = f"""
=====================================================================
Verification failure might be due to a transformers version mismatch:
Current transformers version: {transformers_version}
Model's required version : {model_transformers_version}
Consider installing the version specified by the model's config:
pip install transformers=={model_transformers_version}
=====================================================================
"""
print(textwrap.dedent(warning_message))
sys.exit(1)

View File

@ -0,0 +1,76 @@
#!/usr/bin/env python3
import argparse
import sys
from common import compare_tokens # type: ignore
def parse_arguments():
parser = argparse.ArgumentParser(
description='Compare tokens between two models',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
%(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16
"""
)
parser.add_argument(
'original',
help='Original model name'
)
parser.add_argument(
'converted',
help='Converted model name'
)
parser.add_argument(
'-s', '--suffix',
default='',
help='Type suffix (e.g., "-embeddings")'
)
parser.add_argument(
'-d', '--data-dir',
default='data',
help='Directory containing token files (default: data)'
)
parser.add_argument(
'-v', '--verbose',
action='store_true',
help='Print prompts from both models'
)
return parser.parse_args()
def main():
args = parse_arguments()
if args.verbose:
from pathlib import Path
data_dir = Path(args.data_dir)
prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt"
prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt"
if prompt1_file.exists():
print(f"\nOriginal model prompt ({args.original}):")
print(f" {prompt1_file.read_text().strip()}")
if prompt2_file.exists():
print(f"\nConverted model prompt ({args.converted}):")
print(f" {prompt2_file.read_text().strip()}")
print()
result = compare_tokens(
args.original,
args.converted,
type_suffix=args.suffix,
output_dir=args.data_dir
)
# Enable the script to be used in shell scripts so that they can check
# the exit code for success/failure.
sys.exit(0 if result else 1)
if __name__ == "__main__":
main()

View File

@ -4,8 +4,10 @@ import numpy as np
import argparse
import os
import importlib
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel
from common import compare_tokens, exit_with_warning # type: ignore[import-not-found]
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
@ -157,9 +159,24 @@ def main():
else:
prompt = args.prompt
python_emb_path = Path(args.python_embeddings)
cpp_emb_path = Path(args.cpp_embeddings)
# Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name")
python_model_name = python_emb_path.stem.replace("-embeddings", "")
cpp_model_name = cpp_emb_path.stem.replace("-embeddings", "")
print("Semantic Similarity Test Between Python and llama.cpp Embedding Models")
print("=" * 70)
# First verify tokens match before comparing embeddings
print("\n🔍 Token Comparison Check")
print("=" * 70)
data_dir = python_emb_path.parent
if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)):
exit_with_warning("\n❌ Token mismatch detected", args.model_path)
print()
# Single prompt detailed comparison
print(f"\nTesting with prompt: '{prompt}'")
@ -219,7 +236,7 @@ def main():
elif avg_cross_sim > 0.70:
print("⚠️ FAIR: Models have some differences")
else:
print("❌ POOR: Models are significantly different")
exit_with_warning("❌ POOR: Models are significantly different", args.model_path)
if __name__ == "__main__":
main()

View File

@ -234,6 +234,11 @@
#if UINTPTR_MAX == 0xFFFFFFFF
#define GGML_MEM_ALIGN 4
#elif defined(__EMSCRIPTEN__)
// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm.
// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.)
// ref: https://github.com/ggml-org/llama.cpp/pull/18628
#define GGML_MEM_ALIGN 8
#else
#define GGML_MEM_ALIGN 16
#endif

View File

@ -144,7 +144,7 @@ extern "C" {
// device description: short informative description of the device, could be the model name
const char * (*get_description)(ggml_backend_dev_t dev);
// device memory in bytes
// device memory in bytes: 0 bytes to indicate no memory to report
void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);
// device type

View File

@ -2541,27 +2541,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
}
/**
* @brief Determines if a tensor operation should be offloaded to the CANN
* backend.
*
* This function checks if a given tensor operation should be offloaded to the
* CANN backend based on the operation type and the size of the tensor. It
* returns true if the second dimension (ne[1]) of the tensor is greater than or
* equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
*
* @param backend Pointer to the CANN backend.
* @param op Pointer to the tensor operation to check.
* @return bool Returns true if the operation should be offloaded, otherwise
* false.
*/
static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
const int min_batch_size = 32;
GGML_UNUSED(dev);
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
}
/**
* @brief Records an event on the CANN backend stream.
*
@ -2637,6 +2616,7 @@ struct ggml_backend_cann_device_context {
int device;
std::string name;
std::string description;
int op_offload_min_batch_size;
};
static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
@ -2713,6 +2693,26 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(
return ggml_backend_cann_host_buffer_type();
}
/**
* @brief Determines if a tensor operation should be offloaded to the CANN
* backend.
*
* This function checks if a given tensor operation should be offloaded to the
* CANN backend based on the operation type and the size of the tensor. It
* returns true if the second dimension (ne[1]) of the tensor is greater than or
* equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
*
* @param backend Pointer to the CANN backend.
* @param op Pointer to the tensor operation to check.
* @return bool Returns true if the operation should be offloaded, otherwise
* false.
*/
static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS;
}
/**
* @brief Creates a new event for the CANN backend device.
*
@ -2829,12 +2829,14 @@ ggml_backend_reg_t ggml_backend_cann_reg() {
if (!initialized) {
aclInit(nullptr);
ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
for (int i = 0; i < ggml_cann_info().device_count; i++) {
ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context();
dev_ctx->description = aclrtGetSocName();
dev_ctx->device = i;
dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
dev_ctx->op_offload_min_batch_size = min_batch_size;
ggml_cann_set_device(i);
ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface = */ ggml_backend_cann_device_interface,
/* .reg = */ &reg,

View File

@ -47,7 +47,10 @@ if (CUDAToolkit_FOUND)
# check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead.
# However, the architectures 120a-real and 121a-real should work with basically any CMake version and
# until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell.
list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real 121a-real)
list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9")
list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real)
endif()
endif()
endif()

View File

@ -4122,6 +4122,7 @@ struct ggml_backend_cuda_device_context {
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@ -4676,11 +4677,9 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
}
static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
const int min_batch_size = 32;
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
return get_op_batch_size(op) >= min_batch_size;
GGML_UNUSED(dev);
return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
}
static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
@ -4848,6 +4847,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
for (int i = 0; i < ggml_cuda_info().device_count; i++) {
ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
@ -4861,6 +4861,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
dev_ctx->pci_bus_id = pci_bus_id;
dev_ctx->op_offload_min_batch_size = min_batch_size;
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface,

View File

@ -219,6 +219,8 @@ struct ggml_metal_device_props {
bool use_shared_buffers;
bool supports_gpu_family_apple7;
int op_offload_min_batch_size;
};
ggml_metal_device_t ggml_metal_device_init(void);

View File

@ -782,6 +782,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
dev->props.max_buffer_size = dev->mtl_device.maxBufferLength;
dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;

View File

@ -625,14 +625,11 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
}
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
const int min_batch_size = 32;
ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
return (op->op == GGML_OP_MUL_MAT ||
op->op == GGML_OP_MUL_MAT_ID) &&
get_op_batch_size(op) >= min_batch_size;
GGML_UNUSED(dev);
GGML_UNUSED(op);
get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size;
}
static ggml_backend_device_i ggml_backend_metal_device_i = {

View File

@ -9148,6 +9148,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;

View File

@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS
add
add_id
argsort
fill
clamp
cpy
cvt

View File

@ -489,6 +489,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
cl_kernel kernel_relu;
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
cl_kernel kernel_fill;
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
@ -787,6 +788,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// fill
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "fill.cl.h"
};
#else
const std::string kernel_src = read_file("fill.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_fill = clCreateKernel(prog, "kernel_fill_f32", &err), err));
GGML_LOG_CONT(".");
CL_CHECK(clReleaseProgram(prog));
}
// clamp
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -3104,6 +3123,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
default:
return false;
}
case GGML_OP_FILL:
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
@ -4266,8 +4287,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_
}
static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
*free = 1;
*total = 1;
*free = 0;
*total = 0;
GGML_UNUSED(dev);
}
@ -5860,6 +5881,36 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src0);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offsetd = extrad->offset + dst->view_offs;
float v = 0.0f;
memcpy(&v, ((int32_t *) dst->op_params), sizeof(float));
const int64_t n = ggml_nelements(dst);
cl_kernel kernel = backend_ctx->kernel_fill;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(float), &v));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(float), &n));
size_t local_work_size[1] = { 256 };
size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
}
static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@ -9595,6 +9646,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_glu;
break;
case GGML_OP_FILL:
if (!any_on_device) {
return false;
}
func = ggml_cl_fill;
break;
case GGML_OP_CLAMP:
if (!any_on_device) {
return false;

View File

@ -0,0 +1,17 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
//------------------------------------------------------------------------------
// fill
//------------------------------------------------------------------------------
__kernel void kernel_fill_f32(
__global float *dst,
ulong offsetd,
float v,
int n
) {
dst = (global float*)((global char*)dst + offsetd);
if(get_global_id(0) < n){
dst[get_global_id(0)] = v;
}
}

View File

@ -4286,6 +4286,7 @@ struct ggml_backend_sycl_device_context {
int device;
std::string name;
std::string description;
int op_offload_min_batch_size;
};
static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
@ -4674,9 +4675,8 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
}
static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
const int min_batch_size = 32;
return get_op_batch_size(op) >= min_batch_size;
GGML_UNUSED(dev);
ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size;
}
static ggml_backend_event_t
@ -4799,6 +4799,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
for (int i = 0; i < ggml_sycl_info().device_count; i++) {
ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
@ -4812,6 +4813,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() {
prop, dpct::dev_mgr::instance().get_device(i))));
dev_ctx->description = prop.get_name();
dev_ctx->op_offload_min_batch_size = min_batch_size;
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_sycl_device_interface,

View File

@ -570,6 +570,7 @@ struct vk_device_struct {
bool uma;
bool prefer_host_memory;
bool float_controls_rte_fp16;
bool subgroup_basic;
bool subgroup_arithmetic;
bool subgroup_shuffle;
bool subgroup_ballot;
@ -1504,6 +1505,11 @@ template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L);
}
struct vk_quantize_q8_1_push_constants {
uint32_t ne;
uint32_t num_blocks;
};
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@ -2996,6 +3002,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
// Xe2/Xe3 with coopmat enabled - warptile performance tuning
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@ -3336,12 +3346,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
GGML_ASSERT(device->subgroup_ballot);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (device->coopmat_bf16_support) {
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
}
#endif
@ -3449,9 +3459,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
@ -3493,9 +3503,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
#endif
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@ -3610,9 +3620,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
@ -3636,9 +3646,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@ -3678,6 +3688,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_wg_denoms = { 64, 64, 1 };
s_wg_denoms = { 32, 32, 1 };
if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
// Xe2/Xe3 - bf16 warptile performance tuning
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
}
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
}
@ -3831,22 +3846,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
}
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
}
@ -3934,9 +3949,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
if (device->subgroup_clustered && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
}
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
@ -4144,9 +4159,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
#define CREATE_GLU(name) \
if (device->float_controls_rte_fp16) { \
@ -4292,8 +4307,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
@ -4629,6 +4644,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
}
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
#ifdef __APPLE__
@ -5061,11 +5078,23 @@ static vk_device ggml_vk_get_device(size_t idx) {
switch (device->vendor_id) {
#ifndef GGML_VULKAN_RUN_TESTS
case VK_VENDOR_ID_AMD:
device->mul_mat_l[i] = false;
device->mul_mat_m[i] = true;
device->mul_mat_s[i] = true;
device->mul_mat_id_l[i] = false;
device->mul_mat_id_m[i] = true;
device->mul_mat_id_s[i] = true;
break;
case VK_VENDOR_ID_INTEL:
device->mul_mat_l[i] = false;
if (!device->coopmat_support || device->architecture != INTEL_XE2) {
device->mul_mat_l[i] = false;
device->mul_mat_id_l[i] = false;
} else {
device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel
device->mul_mat_id_l[i] = true;
}
device->mul_mat_m[i] = true;
device->mul_mat_s[i] = true;
device->mul_mat_id_l[i] = false;
device->mul_mat_id_m[i] = true;
device->mul_mat_id_s[i] = true;
break;
@ -6076,6 +6105,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants));
vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@ -6858,7 +6888,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 2>{ ne, num_blocks }, { elements, 1, 1 });
const vk_quantize_q8_1_push_constants pc = {
ne,
num_blocks,
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 });
ggml_vk_sync_buffers(ctx, subctx);
}
@ -9849,8 +9884,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
std::array<uint32_t, 3> elements;
const int splitH = 16;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
const uint32_t d_state = src0->ne[0];
uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };
@ -14228,6 +14264,7 @@ struct ggml_backend_vk_device_context {
std::string description;
bool is_integrated_gpu;
std::string pci_bus_id;
int op_offload_min_batch_size;
};
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
@ -14284,6 +14321,19 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const
}
static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
// reject any tensors larger than the max buffer size
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) {
return false;
}
}
if (ggml_nbytes(op) > device->max_buffer_size) {
return false;
}
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
@ -14332,8 +14382,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_MUL_MAT_ID:
{
ggml_type src0_type = op->src[0]->type;
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->op == GGML_OP_MUL_MAT_ID) {
if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
@ -14394,8 +14442,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
bool coopmat2 = device->coopmat2;
uint32_t HSK = op->src[1]->ne[0];
uint32_t HSV = op->src[2]->ne[0];
@ -14617,8 +14663,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// pipeline_argsort_large_f32 requires vulkan memory model.
if (device->vulkan_memory_model) {
return true;
@ -14631,8 +14675,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// We could potentially support larger, using argsort to sort the
// whole thing. Not clear if this is needed.
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
@ -14679,8 +14721,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_CUMSUM:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
}
@ -14688,9 +14728,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
case GGML_OP_SOLVE_TRI:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
return false;
}
@ -14755,14 +14792,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
size_t shmem_size = d_state * sizeof(float);
const uint32_t SPLIT_H = 16;
if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
if (!device->subgroup_basic) {
return false;
}
@ -14802,12 +14838,10 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba
}
static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
const int min_batch_size = 32;
ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
UNUSED(dev);
return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) ||
(op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
}
static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
@ -14933,6 +14967,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
char desc[256];
@ -14942,6 +14977,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
ctx->op_offload_min_batch_size = min_batch_size;
devices.push_back(new ggml_backend_device {
/* .iface = */ ggml_backend_vk_device_i,
/* .reg = */ reg,

View File

@ -462,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
return dm;
}
#endif

View File

@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
const vec2 dm = vec2(data_a_packed32[ib].dm);
const uint vui = data_a_packed32[ib].qs[iqs];
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint uint_qh = data_a_packed16[ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
const vec2 dm = vec2(data_a_packed32[ib].dm);
const uint uint_qh = data_a_packed32[ib].qh;
const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);
const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);
const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);
const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
const uint vui = data_a_packed32[ib].qs[iqs];
const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));
const uint scales = data_a[ib].scales[scalesi];
const vec2 dm = vec2(data_a[ib].dm);
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_Q3_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
fma(d, q.y, m));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
const vec2 q = vec2(unpack8(qs | qh).xy);
const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;
const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
const vec4 q = vec4(unpack8(qs | qh));
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
fma(d, q.y, m));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
kvalues_iq4nl[vui >> 12]);
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;

View File

@ -1,6 +1,7 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
@ -9,7 +10,8 @@
layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;
const uint32_t c_factor = D_STATE / SUBGROUP_SIZE;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
@ -41,22 +43,28 @@ float softplus(float x) {
}
}
shared float stateC[SPLIT_H * D_STATE];
#if !USE_SUBGROUP_ADD
shared float temp[D_STATE];
#endif
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
const uint seq_idx = gl_WorkGroupID.y;
const uint subgroup = gl_SubgroupID;
const uint lane = gl_SubgroupInvocationID;
const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane;
const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup;
const uint head_idx = subgroup_idx / d_head;
const uint head_off = (subgroup_idx % d_head) * 4;
const uint seq_idx = gl_WorkGroupID.y;
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4;
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
const uint A_base_idx = (head_idx * nb31) / 4;
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx;
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
const uint stride_x = nb12 / 4;
@ -65,76 +73,52 @@ void main() {
const uint stride_C = nb52 / 4;
const uint stride_y = n_head * d_head;
float state[SPLIT_H];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
state[j] = s0[s0_base_idx + j * D_STATE + tid];
float state[c_factor];
[[unroll]] for (uint j = 0; j < c_factor; j++) {
state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane];
}
float a = A[A_base_idx];
for (uint i = 0; i < n_tok; i++) {
const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
const float dA = exp(dt_soft_plus * A[A_base_idx]);
const float B_val = B[B_base_idx + i * stride_B + tid];
const float C_val = C[C_base_idx + i * stride_C + tid];
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
float state_sum = 0.0f;
const float dA = exp(dt_soft_plus * a);
const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus;
[[unroll]] for (uint j = 0; j < c_factor; j++) {
float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane];
float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane];
state[j] = (state[j] * dA) + (B_val * x_dt);
stateC[j * D_STATE + tid] = state[j] * C_val;
state_sum += state[j] * C_val;
}
#if USE_SUBGROUP_ADD
state_sum = subgroupAdd(state_sum);
#else
temp[tid] = state_sum;
barrier();
[[unroll]]
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
stateC[k] += stateC[k + w];
}
[[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) {
if (lane < s) {
temp[tid] += temp[tid + s];
}
barrier();
}
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
const uint idx = (tid % SUBGROUP_SIZE) +
D_STATE * (tid / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
const uint max_idx = SUBGROUP_SIZE - 1 +
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
if (idx < SPLIT_H * D_STATE ||
max_idx < SPLIT_H * D_STATE) {
float sc;
#if USE_SUBGROUP_ADD
sc = stateC[idx];
sc = subgroupAdd(sc);
#else
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
if (idx + offset < SPLIT_H * D_STATE) {
stateC[idx] += stateC[idx + offset];
}
barrier();
}
if (tid % SUBGROUP_SIZE == 0) {
sc = stateC[idx];
}
// get the value from lane 0
state_sum = temp[subgroup * SUBGROUP_SIZE];
barrier();
#endif
if (tid % SUBGROUP_SIZE == 0) {
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
d[y_base_idx + i * stride_y + k] = sc;
}
}
if (lane == 0) {
d[y_base_idx + i * stride_y] = state_sum;
}
barrier();
}
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
d[s_base_idx + j * D_STATE + tid] = state[j];
// write back the state
[[unroll]]
for (int j = 0; j < c_factor; j++) {
d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j];
}
}

View File

@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
for (const auto& tname : type_names) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
load_vec_quant = "4";
if (tname == "bf16") {

View File

@ -0,0 +1,169 @@
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
#define GGML_WEBGPU_SHADER_LIB_HPP
#include "ggml.h"
#include "pre_wgsl.hpp"
#include <string>
#include <vector>
#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_F32_SIZE_BYTES 4
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u
struct ggml_webgpu_flash_attn_shader_lib_context {
ggml_type kv_type;
uint32_t head_dim_qk;
uint32_t head_dim_v;
bool kv_direct;
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
size_t wg_mem_limit_bytes;
uint32_t max_subgroup_size;
};
struct ggml_webgpu_flash_attn_shader_decisions {
uint32_t q_tile = 0;
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
};
struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
ggml_webgpu_flash_attn_shader_decisions decisions;
};
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
uint32_t head_dim_qk,
uint32_t head_dim_v,
bool has_mask,
bool kv_direct) {
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
size_t f16_elems = 0;
size_t f32_elems = 0;
f16_elems += q_tile * head_dim_qk; // q_shmem
if (!kv_direct) {
f16_elems += kv_tile * max_head_dim; // kv_shmem
}
f16_elems += q_tile * head_dim_v; // o_shmem
if (has_mask) {
f16_elems += q_tile * kv_tile; // mask_shmem
}
f16_elems += q_tile * kv_tile; // inter_shmem
f32_elems += q_tile; // row_max_shmem
f32_elems += q_tile; // exp_sum_shmem
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
}
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
const size_t limit_bytes = context.wg_mem_limit_bytes;
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!context.kv_direct) {
bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
}
if (context.has_mask) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
}
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_flash_attn_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "flash_attn";
switch (context.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
case GGML_TYPE_F16:
defines.push_back("KV_F16");
break;
case GGML_TYPE_Q4_0:
defines.push_back("KV_Q4_0");
break;
case GGML_TYPE_Q8_0:
defines.push_back("KV_Q8_0");
break;
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
variant += std::string("_") + ggml_type_name(context.kv_type);
if (context.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
if (context.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (context.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
if (context.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.head_dim_v);
// For now these are not part of the variant name
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
// Add chosen Q/KV tile sizes
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
if (context.kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
// Avoids having to use bounds-checks and decreasing performance for direct KV loads
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= context.sg_mat_n;
}
}
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
// workgroup size
uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
result.decisions.q_tile = q_tile;
result.decisions.kv_tile = kv_tile;
result.decisions.wg_size = wg_size;
return result;
}
#endif // GGML_WEBGPU_SHADER_LIB_HPP

View File

@ -7,7 +7,9 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
#include "ggml-wgsl-shaders.hpp"
#include "pre_wgsl.hpp"
#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
@ -30,7 +32,7 @@
#ifdef GGML_WEBGPU_DEBUG
# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
# define WEBGPU_DEBUG_BUF_ELEMS 32
# define WEBGPU_DEBUG_BUF_ELEMS 512
#else
# define WEBGPU_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_WEBGPU_DEBUG
@ -251,6 +253,7 @@ struct webgpu_gpu_profile_buf_pool {
struct webgpu_pipeline {
wgpu::ComputePipeline pipeline;
std::string name;
void * context = nullptr;
};
struct webgpu_command {
@ -263,6 +266,46 @@ struct webgpu_command {
#endif
};
struct flash_attn_pipeline_key {
int q_type;
int kv_type;
int dst_type;
uint32_t head_dim_qk;
uint32_t head_dim_v;
bool kv_direct;
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
bool operator==(const flash_attn_pipeline_key & other) const {
return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap;
}
};
// Same hash combine function as in boost
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
struct flash_attn_pipeline_key_hash {
size_t operator()(const flash_attn_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.q_type);
ggml_webgpu_hash_combine(seed, key.kv_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.kv_direct);
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
return seed;
}
};
// All the base objects needed to run operations on a WebGPU device
struct webgpu_context_struct {
wgpu::Instance instance;
@ -271,12 +314,12 @@ struct webgpu_context_struct {
wgpu::Queue queue;
wgpu::Limits limits;
uint32_t subgroup_size;
uint32_t max_subgroup_size;
#ifndef __EMSCRIPTEN__
bool supports_subgroup_matrix = false;
wgpu::SubgroupMatrixConfig subgroup_matrix_config;
#endif
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
std::recursive_mutex mutex;
std::atomic_uint inflight_threads = 0;
@ -284,20 +327,24 @@ struct webgpu_context_struct {
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;
pre_wgsl::Preprocessor p;
std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
@ -361,8 +408,6 @@ struct ggml_backend_webgpu_buffer_context {
label(std::move(lbl)) {}
};
/* End struct definitions */
/* WebGPU object initializations */
// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
@ -484,14 +529,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
wgpu::CommandBuffer commands = encoder.Finish();
ctx->queue.Submit(1, &commands);
ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
std::cout << "debug data:";
for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
std::cout << " " << i << ": " << debug_data[i];
}
std::cout << "\n";
const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
std::cout << "debug[0]: " << debug_data[0] << "\n";
ctx->debug_host_buf.Unmap();
}
#endif
@ -673,6 +713,7 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
return ctx->name.c_str();
}
// TODO: implement proper cleanup
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
@ -730,12 +771,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
return ctx->buffer;
}
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
size_t offset = ggml_webgpu_tensor_offset(t);
return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
}
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
size_t offset = ggml_webgpu_tensor_offset(t);
return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
}
@ -964,12 +1005,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
#ifndef __EMSCRIPTEN__
if (ctx->supports_subgroup_matrix) {
// The total number of subgroups/workgroups needed per matrix.
uint32_t wg_m_sg_tile =
WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
uint32_t wg_n_sg_tile =
WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
} else {
#endif
uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
@ -986,6 +1025,146 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
ggml_tensor * V,
ggml_tensor * mask,
ggml_tensor * sinks,
ggml_tensor * dst) {
float scale = *(float *) dst->op_params;
float max_bias;
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
float logit_softcap;
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int has_mask = (mask != nullptr);
const int has_sinks = (sinks != nullptr);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) Q->ne[2], // number of heads
(uint32_t) Q->ne[1], // sequence length (Q)
(uint32_t) K->ne[1], // sequence length (K/V)
(uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1
(uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2
(uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3
(uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1
(uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2
(uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3
(uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
(uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
(uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
(uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
*(uint32_t *) &scale, // scale (possibly adjusted for logit softcap)
*(uint32_t *) &max_bias,
*(uint32_t *) &logit_softcap,
*(uint32_t *) &n_head_log2,
*(uint32_t *) &m0,
*(uint32_t *) &m1
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(Q),
.offset = ggml_webgpu_tensor_align_offset(ctx, Q),
.size = ggml_webgpu_tensor_binding_size(ctx, Q) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(K),
.offset = ggml_webgpu_tensor_align_offset(ctx, K),
.size = ggml_webgpu_tensor_binding_size(ctx, K) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(V),
.offset = ggml_webgpu_tensor_align_offset(ctx, V),
.size = ggml_webgpu_tensor_binding_size(ctx, V) }
};
uint32_t binding_index = 3;
if (has_mask) {
entries.push_back({ .binding = binding_index++,
.buffer = ggml_webgpu_tensor_buf(mask),
.offset = ggml_webgpu_tensor_align_offset(ctx, mask),
.size = ggml_webgpu_tensor_binding_size(ctx, mask) });
}
if (has_sinks) {
entries.push_back({ .binding = binding_index++,
.buffer = ggml_webgpu_tensor_buf(sinks),
.offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
.size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
}
entries.push_back({ .binding = binding_index++,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
bool kv_direct =
(K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
flash_attn_pipeline_key key = {
.q_type = Q->type,
.kv_type = K->type,
.dst_type = dst->type,
.head_dim_qk = (uint32_t) Q->ne[0],
.head_dim_v = (uint32_t) V->ne[0],
.kv_direct = kv_direct,
.has_mask = static_cast<bool>(has_mask),
.has_sinks = static_cast<bool>(has_sinks),
.uses_logit_softcap = logit_softcap != 0.0f,
};
webgpu_pipeline pipeline;
ggml_webgpu_flash_attn_shader_decisions decisions = {};
auto it = ctx->flash_attn_pipelines.find(key);
if (it != ctx->flash_attn_pipelines.end()) {
pipeline = it->second;
decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
} else {
std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
it = ctx->flash_attn_pipelines.find(key);
if (it != ctx->flash_attn_pipelines.end()) {
pipeline = it->second;
decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
} else {
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type,
.head_dim_qk = (uint32_t) Q->ne[0],
.head_dim_v = (uint32_t) V->ne[0],
.kv_direct = kv_direct,
.has_mask = static_cast<bool>(has_mask),
.has_sinks = static_cast<bool>(has_sinks),
.uses_logit_softcap = logit_softcap != 0.0f,
.sg_mat_m = ctx->sg_mat_m,
.sg_mat_n = ctx->sg_mat_n,
.sg_mat_k = ctx->sg_mat_k,
.wg_mem_limit_bytes =
ctx->limits.maxComputeWorkgroupStorageSize,
.max_subgroup_size = ctx->max_subgroup_size };
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
ctx->flash_attn_pipelines.emplace(key, pipeline);
decisions = processed.decisions;
}
}
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
ggml_unary_op unary_op = ggml_get_unary_op(dst);
@ -1397,6 +1576,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return ggml_webgpu_get_rows(ctx, src0, src1, node);
case GGML_OP_MUL_MAT:
return ggml_webgpu_mul_mat(ctx, src0, src1, node);
case GGML_OP_FLASH_ATTN_EXT:
return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
case GGML_OP_ADD:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
@ -1466,6 +1647,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
futures.push_back(new_futures);
}
ggml_backend_webgpu_wait(ctx, futures);
ctx->inflight_threads--;
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
@ -1808,15 +1990,15 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
#ifndef __EMSCRIPTEN__
if (webgpu_ctx->supports_subgroup_matrix) {
std::map<std::string, std::string> sg_matrix_repls;
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m);
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n);
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k);
proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
proc_mul_mat_f32_f32_vec =
@ -2328,6 +2510,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
}
// TODO: move most initialization logic here
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
@ -2489,6 +2672,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
}
break;
}
case GGML_OP_FLASH_ATTN_EXT:
{
if (!webgpu_ctx->supports_subgroup_matrix) {
break;
}
// Head dimensions must fit in workgroup memory with minimum tile sizes
size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
(src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
has_mask, kv_direct);
if (min_bytes > limit_bytes) {
break;
}
supports_op = src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
src2->type == src1->type && op->type == GGML_TYPE_F32;
break;
}
case GGML_OP_RMS_NORM:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
@ -2606,6 +2812,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
}
// TODO: Does this need to be thread safe? Is it only called once?
// TODO: move most logic to device_init function so backend can be freed/initialized properly
// Only one device is supported for now
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
@ -2665,7 +2872,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
ctx->subgroup_matrix_config = config;
ctx->sg_mat_m = config.M;
ctx->sg_mat_n = config.N;
ctx->sg_mat_k = config.K;
valid_subgroup_matrix_config = true;
break;
}
@ -2676,7 +2885,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
#endif
// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
ctx->subgroup_size = info.subgroupMaxSize;
ctx->max_subgroup_size = info.subgroupMaxSize;
// Initialize device
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
@ -2701,8 +2910,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
GGML_UNUSED(device);
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
GGML_UNUSED(reason);
GGML_UNUSED(message);
//TODO: uncomment once proper free logic is in place
//GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
//std::string(message).c_str());
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {

View File

@ -0,0 +1,778 @@
#ifndef PRE_WGSL_HPP
#define PRE_WGSL_HPP
#include <cctype>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace pre_wgsl {
//==============================================================
// Options
//==============================================================
struct Options {
std::string include_path = ".";
std::vector<std::string> macros;
};
//==============================================================
// Utility: trim
//==============================================================
static std::string trim(const std::string & s) {
size_t a = 0;
while (a < s.size() && std::isspace((unsigned char) s[a])) {
a++;
}
size_t b = s.size();
while (b > a && std::isspace((unsigned char) s[b - 1])) {
b--;
}
return s.substr(a, b - a);
}
static std::string trim_value(std::istream & is) {
std::string str;
std::getline(is, str);
return trim(str);
}
static bool isIdentChar(char c) {
return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
}
static std::string expandMacrosRecursiveInternal(const std::string & line,
const std::unordered_map<std::string, std::string> & macros,
std::unordered_set<std::string> & visiting);
static std::string expandMacroValue(const std::string & name,
const std::unordered_map<std::string, std::string> & macros,
std::unordered_set<std::string> & visiting) {
if (visiting.count(name)) {
throw std::runtime_error("Recursive macro: " + name);
}
visiting.insert(name);
auto it = macros.find(name);
if (it == macros.end()) {
visiting.erase(name);
return name;
}
const std::string & value = it->second;
if (value.empty()) {
visiting.erase(name);
return "";
}
std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting);
visiting.erase(name);
return expanded;
}
static std::string expandMacrosRecursiveInternal(const std::string & line,
const std::unordered_map<std::string, std::string> & macros,
std::unordered_set<std::string> & visiting) {
std::string result;
result.reserve(line.size());
size_t i = 0;
while (i < line.size()) {
if (isIdentChar(line[i])) {
size_t start = i;
while (i < line.size() && isIdentChar(line[i])) {
i++;
}
std::string token = line.substr(start, i - start);
auto it = macros.find(token);
if (it != macros.end()) {
result += expandMacroValue(token, macros, visiting);
} else {
result += token;
}
} else {
result += line[i];
i++;
}
}
return result;
}
static std::string expandMacrosRecursive(const std::string & line,
const std::unordered_map<std::string, std::string> & macros) {
std::unordered_set<std::string> visiting;
return expandMacrosRecursiveInternal(line, macros, visiting);
}
//==============================================================
// Tokenizer for expressions in #if/#elif
//==============================================================
class ExprLexer {
public:
enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN };
struct Tok {
Kind kind;
std::string text;
};
explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {}
Tok next() {
skipWS();
if (pos >= src.size()) {
return { END, "" };
}
char c = src[pos];
// number
if (std::isdigit((unsigned char) c)) {
size_t start = pos;
while (pos < src.size() && std::isdigit((unsigned char) src[pos])) {
pos++;
}
return { NUMBER, std::string(src.substr(start, pos - start)) };
}
// identifier
if (std::isalpha((unsigned char) c) || c == '_') {
size_t start = pos;
while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) {
pos++;
}
return { IDENT, std::string(src.substr(start, pos - start)) };
}
if (c == '(') {
pos++;
return { LPAREN, "(" };
}
if (c == ')') {
pos++;
return { RPAREN, ")" };
}
// multi-char operators
static const char * two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" };
for (auto op : two_ops) {
if (src.substr(pos, 2) == op) {
pos += 2;
return { OP, std::string(op) };
}
}
// single-char operators
if (std::string("+-*/%<>!").find(c) != std::string::npos) {
pos++;
return { OP, std::string(1, c) };
}
// unexpected
pos++;
return { END, "" };
}
private:
std::string_view src;
size_t pos;
void skipWS() {
while (pos < src.size() && std::isspace((unsigned char) src[pos])) {
pos++;
}
}
};
//==============================================================
// Expression Parser (recursive descent)
//==============================================================
class ExprParser {
public:
ExprParser(std::string_view expr,
const std::unordered_map<std::string, std::string> & macros,
std::unordered_set<std::string> & visiting) :
lex(expr),
macros(macros),
visiting(visiting) {
advance();
}
int parse() { return parseLogicalOr(); }
private:
ExprLexer lex;
ExprLexer::Tok tok;
const std::unordered_map<std::string, std::string> & macros;
std::unordered_set<std::string> & visiting;
void advance() { tok = lex.next(); }
bool acceptOp(const std::string & s) {
if (tok.kind == ExprLexer::OP && tok.text == s) {
advance();
return true;
}
return false;
}
bool acceptKind(ExprLexer::Kind k) {
if (tok.kind == k) {
advance();
return true;
}
return false;
}
int parseLogicalOr() {
int v = parseLogicalAnd();
while (acceptOp("||")) {
int rhs = parseLogicalAnd();
v = (v || rhs);
}
return v;
}
int parseLogicalAnd() {
int v = parseEquality();
while (acceptOp("&&")) {
int rhs = parseEquality();
v = (v && rhs);
}
return v;
}
int parseEquality() {
int v = parseRelational();
for (;;) {
if (acceptOp("==")) {
int rhs = parseRelational();
v = (v == rhs);
} else if (acceptOp("!=")) {
int rhs = parseRelational();
v = (v != rhs);
} else {
break;
}
}
return v;
}
int parseRelational() {
int v = parseShift();
for (;;) {
if (acceptOp("<")) {
int rhs = parseShift();
v = (v < rhs);
} else if (acceptOp(">")) {
int rhs = parseShift();
v = (v > rhs);
} else if (acceptOp("<=")) {
int rhs = parseShift();
v = (v <= rhs);
} else if (acceptOp(">=")) {
int rhs = parseShift();
v = (v >= rhs);
} else {
break;
}
}
return v;
}
int parseShift() {
int v = parseAdd();
for (;;) {
if (acceptOp("<<")) {
int rhs = parseAdd();
v = (v << rhs);
} else if (acceptOp(">>")) {
int rhs = parseAdd();
v = (v >> rhs);
} else {
break;
}
}
return v;
}
int parseAdd() {
int v = parseMult();
for (;;) {
if (acceptOp("+")) {
int rhs = parseMult();
v = (v + rhs);
} else if (acceptOp("-")) {
int rhs = parseMult();
v = (v - rhs);
} else {
break;
}
}
return v;
}
int parseMult() {
int v = parseUnary();
for (;;) {
if (acceptOp("*")) {
int rhs = parseUnary();
v = (v * rhs);
} else if (acceptOp("/")) {
int rhs = parseUnary();
v = (rhs == 0 ? 0 : v / rhs);
} else if (acceptOp("%")) {
int rhs = parseUnary();
v = (rhs == 0 ? 0 : v % rhs);
} else {
break;
}
}
return v;
}
int parseUnary() {
if (acceptOp("!")) {
return !parseUnary();
}
if (acceptOp("-")) {
return -parseUnary();
}
if (acceptOp("+")) {
return +parseUnary();
}
return parsePrimary();
}
int parsePrimary() {
// '(' expr ')'
if (acceptKind(ExprLexer::LPAREN)) {
int v = parse();
if (!acceptKind(ExprLexer::RPAREN)) {
throw std::runtime_error("missing ')'");
}
return v;
}
// number
if (tok.kind == ExprLexer::NUMBER) {
int v = std::stoi(tok.text);
advance();
return v;
}
// defined(identifier)
if (tok.kind == ExprLexer::IDENT && tok.text == "defined") {
advance();
if (acceptKind(ExprLexer::LPAREN)) {
if (tok.kind != ExprLexer::IDENT) {
throw std::runtime_error("expected identifier in defined()");
}
std::string name = tok.text;
advance();
if (!acceptKind(ExprLexer::RPAREN)) {
throw std::runtime_error("missing ) in defined()");
}
return macros.count(name) ? 1 : 0;
} else {
// defined NAME
if (tok.kind != ExprLexer::IDENT) {
throw std::runtime_error("expected identifier in defined NAME");
}
std::string name = tok.text;
advance();
return macros.count(name) ? 1 : 0;
}
}
// identifier -> treat as integer, if defined use its value else 0
if (tok.kind == ExprLexer::IDENT) {
std::string name = tok.text;
advance();
auto it = macros.find(name);
if (it == macros.end()) {
return 0;
}
if (it->second.empty()) {
return 1;
}
return evalMacroExpression(name, it->second);
}
// unexpected
return 0;
}
int evalMacroExpression(const std::string & name, const std::string & value) {
if (visiting.count(name)) {
throw std::runtime_error("Recursive macro: " + name);
}
visiting.insert(name);
ExprParser ep(value, macros, visiting);
int v = ep.parse();
visiting.erase(name);
return v;
}
};
//==============================================================
// Preprocessor
//==============================================================
class Preprocessor {
public:
explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) {
// Treat empty include path as current directory
if (opts_.include_path.empty()) {
opts_.include_path = ".";
}
parseMacroDefinitions(opts_.macros);
}
std::string preprocess_file(const std::string & filename, const std::vector<std::string> & additional_macros = {}) {
std::unordered_map<std::string, std::string> macros;
std::unordered_set<std::string> predefined;
std::unordered_set<std::string> include_stack;
buildMacros(additional_macros, macros, predefined);
std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All);
return result;
}
std::string preprocess(const std::string & contents, const std::vector<std::string> & additional_macros = {}) {
std::unordered_map<std::string, std::string> macros;
std::unordered_set<std::string> predefined;
std::unordered_set<std::string> include_stack;
buildMacros(additional_macros, macros, predefined);
std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All);
return result;
}
std::string preprocess_includes_file(const std::string & filename) {
std::unordered_map<std::string, std::string> macros;
std::unordered_set<std::string> predefined;
std::unordered_set<std::string> include_stack;
std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
return result;
}
std::string preprocess_includes(const std::string & contents) {
std::unordered_map<std::string, std::string> macros;
std::unordered_set<std::string> predefined;
std::unordered_set<std::string> include_stack;
std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
return result;
}
private:
Options opts_;
std::unordered_map<std::string, std::string> global_macros;
enum class DirectiveMode { All, IncludesOnly };
struct Cond {
bool parent_active;
bool active;
bool taken;
};
//----------------------------------------------------------
// Parse macro definitions into global_macros
//----------------------------------------------------------
void parseMacroDefinitions(const std::vector<std::string> & macro_defs) {
for (const auto & def : macro_defs) {
size_t eq_pos = def.find('=');
if (eq_pos != std::string::npos) {
// Format: NAME=VALUE
std::string name = trim(def.substr(0, eq_pos));
std::string value = trim(def.substr(eq_pos + 1));
global_macros[name] = value;
} else {
// Format: NAME
std::string name = trim(def);
global_macros[name] = "";
}
}
}
//----------------------------------------------------------
// Build combined macro map and predefined set for a preprocessing operation
//----------------------------------------------------------
void buildMacros(const std::vector<std::string> & additional_macros,
std::unordered_map<std::string, std::string> & macros,
std::unordered_set<std::string> & predefined) {
macros = global_macros;
predefined.clear();
for (const auto & [name, value] : global_macros) {
predefined.insert(name);
}
for (const auto & def : additional_macros) {
size_t eq_pos = def.find('=');
std::string name, value;
if (eq_pos != std::string::npos) {
name = trim(def.substr(0, eq_pos));
value = trim(def.substr(eq_pos + 1));
} else {
name = trim(def);
value = "";
}
// Add to macros map (will override global if same name)
macros[name] = value;
predefined.insert(name);
}
}
//----------------------------------------------------------
// Helpers
//----------------------------------------------------------
std::string loadFile(const std::string & fname) {
std::ifstream f(fname);
if (!f.is_open()) {
throw std::runtime_error("Could not open file: " + fname);
}
std::stringstream ss;
ss << f.rdbuf();
return ss.str();
}
bool condActive(const std::vector<Cond> & cond) const {
if (cond.empty()) {
return true;
}
return cond.back().active;
}
//----------------------------------------------------------
// Process a file
//----------------------------------------------------------
std::string processFile(const std::string & name,
std::unordered_map<std::string, std::string> & macros,
const std::unordered_set<std::string> & predefined_macros,
std::unordered_set<std::string> & include_stack,
DirectiveMode mode) {
if (include_stack.count(name)) {
throw std::runtime_error("Recursive include: " + name);
}
include_stack.insert(name);
std::string shader_code = loadFile(name);
std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode);
include_stack.erase(name);
return out;
}
std::string processIncludeFile(const std::string & fname,
std::unordered_map<std::string, std::string> & macros,
const std::unordered_set<std::string> & predefined_macros,
std::unordered_set<std::string> & include_stack,
DirectiveMode mode) {
std::string full_path = opts_.include_path + "/" + fname;
return processFile(full_path, macros, predefined_macros, include_stack, mode);
}
//----------------------------------------------------------
// Process text
//----------------------------------------------------------
std::string processString(const std::string & shader_code,
std::unordered_map<std::string, std::string> & macros,
const std::unordered_set<std::string> & predefined_macros,
std::unordered_set<std::string> & include_stack,
DirectiveMode mode) {
std::vector<Cond> cond; // Conditional stack for this shader
std::stringstream out;
std::istringstream in(shader_code);
std::string line;
while (std::getline(in, line)) {
std::string t = trim(line);
if (!t.empty() && t[0] == '#') {
bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
if (mode == DirectiveMode::IncludesOnly && !handled) {
out << line << "\n";
}
} else {
if (mode == DirectiveMode::IncludesOnly) {
out << line << "\n";
} else if (condActive(cond)) {
// Expand macros in the line before outputting
std::string expanded = expandMacrosRecursive(line, macros);
out << expanded << "\n";
}
}
}
if (mode == DirectiveMode::All && !cond.empty()) {
throw std::runtime_error("Unclosed #if directive");
}
return out.str();
}
//----------------------------------------------------------
// Directive handler
//----------------------------------------------------------
bool handleDirective(const std::string & t,
std::stringstream & out,
std::unordered_map<std::string, std::string> & macros,
const std::unordered_set<std::string> & predefined_macros,
std::vector<Cond> & cond,
std::unordered_set<std::string> & include_stack,
DirectiveMode mode) {
// split into tokens
std::string body = t.substr(1);
std::istringstream iss(body);
std::string cmd;
iss >> cmd;
if (cmd == "include") {
if (mode == DirectiveMode::All && !condActive(cond)) {
return true;
}
std::string file;
iss >> file;
if (file.size() >= 2 && file.front() == '"' && file.back() == '"') {
file = file.substr(1, file.size() - 2);
}
out << processIncludeFile(file, macros, predefined_macros, include_stack, mode);
return true;
}
if (mode == DirectiveMode::IncludesOnly) {
return false;
}
if (cmd == "define") {
if (!condActive(cond)) {
return true;
}
std::string name;
iss >> name;
// Don't override predefined macros from options
if (predefined_macros.count(name)) {
return true;
}
std::string value = trim_value(iss);
macros[name] = value;
return true;
}
if (cmd == "undef") {
if (!condActive(cond)) {
return true;
}
std::string name;
iss >> name;
// Don't undef predefined macros from options
if (predefined_macros.count(name)) {
return true;
}
macros.erase(name);
return true;
}
if (cmd == "ifdef") {
std::string name;
iss >> name;
bool p = condActive(cond);
bool v = macros.count(name);
cond.push_back({ p, p && v, p && v });
return true;
}
if (cmd == "ifndef") {
std::string name;
iss >> name;
bool p = condActive(cond);
bool v = !macros.count(name);
cond.push_back({ p, p && v, p && v });
return true;
}
if (cmd == "if") {
std::string expr = trim_value(iss);
bool p = condActive(cond);
bool v = false;
if (p) {
std::unordered_set<std::string> visiting;
ExprParser ep(expr, macros, visiting);
v = ep.parse() != 0;
}
cond.push_back({ p, p && v, p && v });
return true;
}
if (cmd == "elif") {
std::string expr = trim_value(iss);
if (cond.empty()) {
throw std::runtime_error("#elif without #if");
}
Cond & c = cond.back();
if (!c.parent_active) {
c.active = false;
return true;
}
if (c.taken) {
c.active = false;
return true;
}
std::unordered_set<std::string> visiting;
ExprParser ep(expr, macros, visiting);
bool v = ep.parse() != 0;
c.active = v;
if (v) {
c.taken = true;
}
return true;
}
if (cmd == "else") {
if (cond.empty()) {
throw std::runtime_error("#else without #if");
}
Cond & c = cond.back();
if (!c.parent_active) {
c.active = false;
return true;
}
if (c.taken) {
c.active = false;
} else {
c.active = true;
c.taken = true;
}
return true;
}
if (cmd == "endif") {
if (cond.empty()) {
throw std::runtime_error("#endif without #if");
}
cond.pop_back();
return true;
}
// Unknown directive
throw std::runtime_error("Unknown directive: #" + cmd);
}
};
} // namespace pre_wgsl
#endif // PRE_WGSL_HPP

View File

@ -0,0 +1,591 @@
diagnostic(off, chromium.subgroup_matrix_uniformity);
diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#ifdef KV_F32
#define KV_TYPE f32
#else
#define KV_TYPE f16
#endif
// Default values
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64
// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN
// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension.
#define SG_MAT_M 8
#define SG_MAT_N 8
#define SG_MAT_K 8
// Each workgroup processes one subgroup matrix of Q rows
#define Q_TILE SG_MAT_M
#define KV_TILE 16
#define WG_SIZE 64
// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
// Quantization constants/helpers
#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
// number of quantized elements processed per thread
#if defined(KV_Q4_0)
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
// Ok not to put these in a define block, compiler will remove if unused
fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF;
}
fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
struct Params {
offset_q: u32,
offset_k: u32,
offset_v: u32,
offset_mask: u32,
offset_sinks: u32,
offset_dst: u32,
// shapes of Q/K/V
n_heads: u32,
seq_len_q: u32,
seq_len_kv: u32,
// strides (in elements)
stride_q1: u32,
stride_q2: u32,
stride_q3: u32,
stride_k1: u32,
stride_k2: u32,
stride_k3: u32,
stride_v1: u32,
stride_v2: u32,
stride_v3: u32,
stride_mask3: u32,
// repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
q_per_kv: u32,
// softmax params
scale: f32,
max_bias: f32,
logit_softcap: f32,
n_head_log2: f32,
m0: f32,
m1: f32,
};
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
#if defined(MASK) && defined(SINKS)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
#define DST_BINDING 5
#define PARAMS_BINDING 6
#elif defined(MASK)
@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
#define DST_BINDING 4
#define PARAMS_BINDING 5
#elif defined(SINKS)
@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
#define DST_BINDING 4
#define PARAMS_BINDING 5
#else
#define DST_BINDING 3
#define PARAMS_BINDING 4
#endif
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<f32>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
// Just a very small float value.
const FLOAT_MIN: f32 = -1.0e9;
// The number of Q rows processed per workgroup
var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
#ifndef KV_DIRECT
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
// we can reuse the same shmem for K and V since we only need one at a time
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
#endif
var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>; // output shmem
#ifdef MASK
// storage for mask values
var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
#endif
// storage for output of Q*K^T scores for online softmax (S matrix from paper)
// also storage for diagonal matrix during online softmax (P matrix from paper)
// note that we reuse the same storage for both since we only need one at a time
var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
// Storage for row max and exp sum during online softmax
var<workgroup> row_max_shmem: array<f32, Q_TILE>;
var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
var v = select(FLOAT_MIN,
f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
kv_idx < KV_TILE);
#ifdef LOGIT_SOFTCAP
v = params.logit_softcap * tanh(v);
#endif
#ifdef MASK
let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
let mask_term = slope * mask_val;
v += mask_term;
#endif
return v;
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(subgroup_id) subgroup_id: u32,
@builtin(subgroup_size) subgroup_size: u32,
@builtin(num_subgroups) num_subgroups: u32,
@builtin(subgroup_invocation_id) sg_inv_id: u32) {
// initialize row max for online softmax
for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
row_max_shmem[i] = FLOAT_MIN;
exp_sum_shmem[i] = 0.0;
}
for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
o_shmem[i] = 0.0;
}
// workgroups per head/batch
let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
let wg_per_batch = wg_per_head * params.n_heads;
let dst2_stride = HEAD_DIM_V * params.n_heads;
let dst3_stride = dst2_stride * params.seq_len_q;
// batch index
let batch_idx = wg_id.x / wg_per_batch;
let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
let wg_in_batch = wg_id.x % wg_per_batch;
// head index
let head_idx = wg_in_batch / wg_per_head;
let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
let k_head_idx = head_idx / params.q_per_kv;
let v_head_idx = k_head_idx;
let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
// starting Q row for this workgroup
let wg_in_head = wg_in_batch % wg_per_head;
let q_row_start = wg_in_head * Q_TILE;
#ifdef MASK
// mask offset
let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
#endif
// note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size]
let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
let head = f32(head_idx);
let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0);
// load q tile into shared memory
for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
let q_row = elem_idx / HEAD_DIM_QK;
let q_col = elem_idx % HEAD_DIM_QK;
let head_q_row = q_row_start + q_row;
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
q_shmem[elem_idx] = f16(select(
0.0,
Q[global_q_row_offset + q_col],
head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
}
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
// clear inter_shmem to ensure zero-initialized accumulators
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
inter_shmem[elem_idx] = 0.0;
}
// load k tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx]; // scale
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx]; // scale
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
let k_row = elem_idx / HEAD_DIM_QK;
let k_col = elem_idx % HEAD_DIM_QK;
let global_k_row = kv_tile + k_row;
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
kv_shmem[elem_idx] = f16(select(
0.0,
K[global_k_row_offset + k_col],
global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
}
#endif
workgroupBarrier();
// accumulate q block * k block into registers across the entire KV tile
// TODO: this loop seems to be the current largest bottleneck
for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
let inter_offset = kv_block * SG_MAT_N;
var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<
subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
#ifdef KV_DIRECT
let k_block_row = kv_tile + kv_block * SG_MAT_N;
let k_global_offset = k_head_offset + k_block_row * params.stride_k1;
#else
let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK;
#endif
for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) {
// load q submatrix from shared memory
var q_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
&q_shmem,
head_dim_block,
false,
HEAD_DIM_QK
);
// load k submatrix from device or shared memory
#ifdef KV_DIRECT
var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
&K,
k_global_offset + head_dim_block,
true,
params.stride_k1
);
#else
var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
&kv_shmem,
k_block_offset + head_dim_block,
true,
HEAD_DIM_QK
);
#endif
acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc);
}
// store acc to shared memory for softmax (S matrix from paper)
subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
}
#ifdef MASK
// load mask tile into shared memory for this KV block
// TODO: optimize and skip if mask is -INF for the entire tile
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
let mask_row = elem_idx / KV_TILE;
let mask_col = elem_idx % KV_TILE;
let global_q_row = q_row_start + mask_row;
let global_k_col = kv_tile + mask_col;
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
}
#endif
workgroupBarrier();
// online softmax
for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
break;
}
// initialize running max for this row
var prev_max = row_max_shmem[q_tile_row];
var final_max = prev_max;
// pass 1: compute final max across the full KV tile in chunks
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
let kv_idx = kv_offset + sg_inv_id;
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
final_max = subgroupMax(max(final_max, softmax_term));
}
var total_exp_term: f32 = 0.0;
// pass 2: compute exp sum and write P using final_max
for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
let kv_idx = kv_offset + sg_inv_id;
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
let cur_p = select(0.0,
exp(softmax_term - final_max),
kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
total_exp_term += subgroupAdd(cur_p);
if (kv_idx < KV_TILE) {
inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
}
}
let cur_exp = exp(prev_max - final_max);
if (sg_inv_id == 0) {
row_max_shmem[q_tile_row] = final_max;
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
}
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
}
}
// load v tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx]; // scale
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx]; // scale
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
let v_row = elem_idx / HEAD_DIM_V;
let v_col = elem_idx % HEAD_DIM_V;
let global_v_row = kv_tile + v_row;
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
kv_shmem[elem_idx] = f16(select(
0.0,
V[global_v_row_offset + v_col],
global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
}
#endif
workgroupBarrier();
// we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
// we want to compute O += P * V across the full KV tile
for (var head_dim_block = subgroup_id * SG_MAT_N;
head_dim_block < HEAD_DIM_V;
head_dim_block += num_subgroups * SG_MAT_N) {
// load O submatrix from shared memory
var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
&o_shmem,
head_dim_block,
false,
HEAD_DIM_V
);
for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
let p_offset = kv_block * SG_MAT_N;
var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
&inter_shmem,
p_offset,
false,
KV_TILE
);
// load V submatrix from global or shared memory
#ifdef KV_DIRECT
let v_block_row = kv_tile + kv_block * SG_MAT_N;
let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
&V,
v_global_offset,
false,
params.stride_v1
);
#else
let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
&kv_shmem,
v_block_offset + head_dim_block,
false,
HEAD_DIM_V
);
#endif
// O += P * V
o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
}
// store O back to shared memory
subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
}
workgroupBarrier();
}
#ifdef SINKS
// add sinks (applied once after processing all KV tiles)
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
// no need to process rows beyond seq_len_q
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
break;
}
var prev_max = row_max_shmem[q_tile_row];
// for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
let new_max = subgroupMax(max(prev_max, sink_val));
let max_exp = exp(prev_max - new_max);
let sink_exp = exp(sink_val - new_max);
let sink_exp_sum = subgroupAdd(sink_exp);
if (sg_inv_id == 0) {
exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
}
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
let idx = q_tile_row * HEAD_DIM_V + elem_idx;
let val = f32(o_shmem[idx]) * max_exp;
o_shmem[idx] = f16(val);
}
}
workgroupBarrier();
#endif
// write output back to global memory
for (var q_tile_row = subgroup_id;
q_tile_row < Q_TILE;
q_tile_row += num_subgroups) {
let global_q_row = q_row_start + q_tile_row;
if (global_q_row >= params.seq_len_q) {
break;
}
let exp_sum = exp_sum_shmem[q_tile_row];
let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0);
for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx];
let scaled = f32(o_val) * scale;
dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled;
}
}
}

View File

@ -309,6 +309,7 @@ extern "C" {
// Keep the booleans together to avoid misalignment during copy-by-value.
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_direct_io; // use direct io, takes precedence over use_mmap
bool use_mlock; // force system to keep model in RAM
bool check_tensors; // validate model tensor data
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
@ -494,7 +495,7 @@ extern "C" {
struct llama_context_params * cparams,
float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements
struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements
size_t margin, // margin of memory to leave per device in bytes
size_t * margins, // margins of memory to leave per device in bytes
uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use
enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log

View File

@ -1,26 +0,0 @@
Copyright (c) 2010-2014, Salvatore Sanfilippo <antirez at gmail dot com>
Copyright (c) 2010-2013, Pieter Noordhuis <pcnoordhuis at gmail dot com>
Copyright (c) 2025, Eric Curtin <ericcurtin17 at gmail dot com>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

67
scripts/pr2wt.sh Executable file
View File

@ -0,0 +1,67 @@
#!/usr/bin/env bash
# intialize a new worktree from a PR number:
#
# - creates a new remote using the fork's clone URL
# - creates a local branch tracking the remote branch
# - creates a new worktree in a parent folder, suffixed with "-pr-${PR}"
#
# sample usage:
# ./scripts/pr2wt.sh 12345
# ./scripts/pr2wt.sh 12345 opencode
# ./scripts/pr2wt.sh 12345 "cmake -B build && cmake --build build"
function usage() {
echo "usage: $0 <pr_number> [cmd]"
exit 1
}
# check we are in the right directory
if [[ ! -f "scripts/pr2wt.sh" ]]; then
echo "error: this script must be run from the root of the repository"
exit 1
fi
if [[ $# -lt 1 || $# -gt 2 ]]; then
usage
fi
PR=$1
[[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; }
url_origin=$(git config --get remote.origin.url) || {
echo "error: no remote named 'origin' in this repository"
exit 1
}
org_repo=$(echo $url_origin | cut -d/ -f4-)
org_repo=${org_repo%.git}
echo "org/repo: $org_repo"
meta=$(curl -sSf -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${org_repo}/pulls/${PR}")
url_remote=$(echo "$meta" | jq -r '.head.repo.clone_url')
head_ref=$(echo "$meta" | jq -r '.head.ref')
echo "url: $url_remote"
echo "head_ref: $head_ref"
git remote rm pr/${PR} 2> /dev/null
git remote add pr/${PR} $url_remote
git fetch pr/${PR} $head_ref
dir=$(basename $(pwd))
git branch -D pr/$PR 2> /dev/null
git worktree add -b pr/$PR ../$dir-pr-$PR pr/$PR/${head_ref} 2> /dev/null
wt_path=$(cd ../$dir-pr-$PR && pwd)
echo "git worktree created in $wt_path"
# if a command was provided, execute it
if [[ $# -eq 2 ]]; then
cd ../$dir-pr-$PR
eval "$2"
fi

View File

@ -16,7 +16,7 @@ vendor = {
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.0/httplib.h": "vendor/cpp-httplib/httplib.h",
"https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h",
}

View File

@ -110,7 +110,7 @@ struct llama_file::impl {
}
}
void read_raw(void * ptr, size_t len) const {
void read_raw(void * ptr, size_t len) {
size_t bytes_read = 0;
while (bytes_read < len) {
size_t chunk_size = std::min<size_t>(len - bytes_read, 64*1024*1024);
@ -127,7 +127,7 @@ struct llama_file::impl {
}
}
uint32_t read_u32() const {
uint32_t read_u32() {
uint32_t val;
read_raw(&val, sizeof(val));
return val;
@ -154,8 +154,8 @@ struct llama_file::impl {
write_raw(&val, sizeof(val));
}
void read_aligned_chunk(size_t offset, void * dest, size_t size) const {
throw std::runtime_error("DirectIO is not implemented on Windows.");
bool has_direct_io() const {
return true;
}
~impl() {
@ -164,33 +164,45 @@ struct llama_file::impl {
}
}
#else
impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) {
impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) : fname(fname) {
#ifdef __linux__
// Try unbuffered I/O for read only
if (use_direct_io && std::strcmp(mode, "rb") == 0) {
fd = open(fname, O_RDONLY | O_DIRECT);
if (fd != -1) {
struct stat file_stats{};
fstat(fd, &file_stats);
size = file_stats.st_size;
alignment = file_stats.st_blksize;
off_t ret = lseek(fd, 0, SEEK_SET);
if (ret == -1) {
throw std::runtime_error(format("seek error: %s", strerror(errno)));
}
if (init_fd()) {
return;
}
LLAMA_LOG_WARN("Failed to open model %s with error: %s. Falling back to buffered I/O",
fname, strerror(errno));
LLAMA_LOG_WARN("Failed to open file '%s' with error: %s. Falling back to buffered I/O",
fname, strerror(errno));
}
#endif
fp = ggml_fopen(fname, mode);
init_fp(mode);
}
#ifdef __linux__
bool init_fd() {
fd = open(fname.c_str(), O_RDONLY | O_DIRECT);
if (fd != -1) {
struct stat file_stats{};
fstat(fd, &file_stats);
size = file_stats.st_size;
alignment = file_stats.st_blksize;
off_t ret = lseek(fd, 0, SEEK_SET);
if (ret == -1) {
throw std::runtime_error(format("seek error: %s", strerror(errno)));
}
return true;
}
return false;
}
#endif
void init_fp(const char * mode) {
fp = ggml_fopen(fname.c_str(), mode);
if (fp == NULL) {
throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
throw std::runtime_error(format("failed to open %s: %s", fname.c_str(), strerror(errno)));
}
seek(0, SEEK_END);
size = tell();
@ -226,7 +238,7 @@ struct llama_file::impl {
}
}
void read_raw(void * ptr, size_t len) const {
void read_raw_unsafe(void * ptr, size_t len) {
if (len == 0) {
return;
}
@ -249,6 +261,17 @@ struct llama_file::impl {
if (errno == EINTR) {
continue; // Interrupted by signal, retry
}
// Fallback to std::fread in case the DMA controller cannot access the buffer
if (errno == EFAULT) {
auto curr_off = tell();
close(fd);
fd = -1;
alignment = 1;
init_fp("rb");
seek(curr_off, SEEK_SET);
read_raw_unsafe(ptr, len);
return;
}
throw std::runtime_error(format("read error: %s", strerror(errno)));
}
if (ret == 0) {
@ -266,7 +289,8 @@ struct llama_file::impl {
}
}
void read_aligned_chunk(size_t offset, void * dest, size_t size) const {
void read_aligned_chunk(void * dest, size_t size) {
size_t offset = tell();
off_t aligned_offset = offset & ~(alignment - 1);
off_t offset_from_alignment = offset - aligned_offset;
size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1);
@ -283,13 +307,21 @@ struct llama_file::impl {
std::unique_ptr<void, aligned_buffer_deleter> buffer(raw_buffer);
seek(aligned_offset, SEEK_SET);
read_raw(buffer.get(), bytes_to_read);
read_raw_unsafe(buffer.get(), bytes_to_read);
uintptr_t actual_data = reinterpret_cast<uintptr_t>(buffer.get()) + offset_from_alignment;
memcpy(dest, reinterpret_cast<void *>(actual_data), size);
}
uint32_t read_u32() const {
void read_raw(void * ptr, size_t len) {
if (has_direct_io()) {
read_aligned_chunk(ptr, len);
} else {
read_raw_unsafe(ptr, len);
}
}
uint32_t read_u32() {
uint32_t ret;
read_raw(&ret, sizeof(ret));
return ret;
@ -310,6 +342,10 @@ struct llama_file::impl {
write_raw(&val, sizeof(val));
}
bool has_direct_io() const {
return fd != -1 && alignment > 1;
}
~impl() {
if (fd != -1) {
close(fd);
@ -318,17 +354,9 @@ struct llama_file::impl {
}
}
int fd = -1;
std::string fname;
#endif
void read_raw_at(void * ptr, size_t len, size_t offset) const {
if (alignment != 1) {
read_aligned_chunk(offset, ptr, len);
} else {
seek(offset, SEEK_SET);
read_raw(ptr, len);
}
}
size_t read_alignment() const {
return alignment;
}
@ -347,6 +375,7 @@ size_t llama_file::tell() const { return pimpl->tell(); }
size_t llama_file::size() const { return pimpl->size; }
size_t llama_file::read_alignment() const { return pimpl->read_alignment(); }
bool llama_file::has_direct_io() const { return pimpl->has_direct_io(); }
int llama_file::file_id() const {
#ifdef _WIN32
@ -361,10 +390,14 @@ int llama_file::file_id() const {
}
void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); }
void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); }
void llama_file::read_raw_at(void * ptr, size_t len, size_t offset) const { pimpl->read_raw_at(ptr, len, offset); }
void llama_file::read_raw(void * ptr, size_t len) { pimpl->read_raw(ptr, len); }
#ifdef _WIN32
void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw(ptr, len); }
#else
void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw_unsafe(ptr, len); }
#endif
uint32_t llama_file::read_u32() const { return pimpl->read_u32(); }
uint32_t llama_file::read_u32() { return pimpl->read_u32(); }
void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); }
void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); }

View File

@ -24,15 +24,16 @@ struct llama_file {
void seek(size_t offset, int whence) const;
void read_raw(void * ptr, size_t len) const;
void read_raw_at(void * ptr, size_t len, size_t offset) const;
void read_aligned_chunk(size_t offset, void * dest, size_t size) const;
uint32_t read_u32() const;
void read_raw(void * ptr, size_t len);
void read_raw_unsafe(void * ptr, size_t len);
void read_aligned_chunk(void * dest, size_t size);
uint32_t read_u32();
void write_raw(const void * ptr, size_t len) const;
void write_u32(uint32_t val) const;
size_t read_alignment() const;
bool has_direct_io() const;
private:
struct impl;
std::unique_ptr<impl> pimpl;

View File

@ -495,6 +495,7 @@ llama_model_loader::llama_model_loader(
const std::string & fname,
std::vector<std::string> & splits,
bool use_mmap,
bool use_direct_io,
bool check_tensors,
bool no_alloc,
const llama_model_kv_override * param_overrides_p,
@ -527,9 +528,17 @@ llama_model_loader::llama_model_loader(
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
llm_kv = LLM_KV(llm_arch_from_string(arch_name));
files.emplace_back(new llama_file(fname.c_str(), "rb", !use_mmap));
files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io));
contexts.emplace_back(ctx);
use_direct_io = use_direct_io && files.back()->has_direct_io();
// Disable mmap in case Direct I/O is enabled and available
if (use_direct_io && use_mmap) {
use_mmap = false;
LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__);
}
// Save tensors data offset of the main file.
// For subsidiary files, `meta` tensor data offset must not be used,
// so we build a unified tensors index for weights.
@ -595,7 +604,7 @@ llama_model_loader::llama_model_loader(
}
}
files.emplace_back(new llama_file(fname_split, "rb", !use_mmap));
files.emplace_back(new llama_file(fname_split, "rb", use_direct_io));
contexts.emplace_back(ctx);
// Save tensors data offset info of the shard.
@ -739,6 +748,7 @@ llama_model_loader::llama_model_loader(
}
this->use_mmap = use_mmap;
this->use_direct_io = use_direct_io;
this->check_tensors = check_tensors;
this->no_alloc = no_alloc;
}
@ -1100,7 +1110,8 @@ bool llama_model_loader::load_all_data(
const auto & file = files.at(weight->idx);
if (ggml_backend_buffer_is_host(cur->buffer)) {
file->read_raw_at(cur->data, n_size, weight->offs);
file->seek(weight->offs, SEEK_SET);
file->read_raw(cur->data, n_size);
if (check_tensors) {
validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
@ -1132,7 +1143,7 @@ bool llama_model_loader::load_all_data(
ggml_backend_event_synchronize(events[buffer_idx]);
// Read aligned chunk from file
file->read_raw(reinterpret_cast<void *>(ptr_dest_aligned), read_size);
file->read_raw_unsafe(reinterpret_cast<void *>(ptr_dest_aligned), read_size);
// Calculate actual data portion (excluding alignment padding)
uintptr_t ptr_data = ptr_dest_aligned;
@ -1162,7 +1173,8 @@ bool llama_model_loader::load_all_data(
}
} else {
read_buf.resize(n_size);
file->read_raw_at(read_buf.data(), n_size, weight->offs);
file->seek(weight->offs, SEEK_SET);
file->read_raw(read_buf.data(), n_size);
ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));

View File

@ -70,6 +70,7 @@ struct llama_model_loader {
size_t n_bytes = 0;
bool use_mmap = false;
bool use_direct_io = false;
bool check_tensors;
bool no_alloc;
@ -97,6 +98,7 @@ struct llama_model_loader {
const std::string & fname,
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme
bool use_mmap,
bool use_direct_io,
bool check_tensors,
bool no_alloc,
const llama_model_kv_override * param_overrides_p,

View File

@ -2473,7 +2473,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const bool use_mmap_buffer = true;
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false");
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n",
__func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false");
// build a list of buffer types for the CPU and GPU devices
pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host);
@ -2484,6 +2485,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
pimpl->gpu_buft_list.emplace(dev, std::move(buft_list));
}
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
// calculate the split points
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; });
std::vector<float> splits(n_devices());
@ -2494,6 +2500,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
size_t total;
size_t free;
ggml_backend_dev_memory(dev, &free, &total);
// devices can return 0 bytes for free and total memory if they do not
// have any to report. in this case, we will use the host memory as a fallback
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
if (free == 0 && total == 0) {
ggml_backend_dev_memory(cpu_dev, &free, &total);
}
splits[i] = free;
}
} else {
@ -2510,10 +2523,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
splits[i] /= split_sum;
}
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0);
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1);
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
@ -8152,6 +8161,7 @@ llama_model_params llama_model_default_params() {
/*.kv_overrides =*/ nullptr,
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_direct_io =*/ true,
/*.use_mlock =*/ false,
/*.check_tensors =*/ false,
/*.use_extra_bufts =*/ true,

View File

@ -596,7 +596,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
std::vector<std::string> splits = {};
llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
ml.init_mappings(false); // no prefetching
llama_model model(llama_model_default_params());

View File

@ -111,8 +111,20 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
}
}
for (size_t i = 0; i < ret.size(); i++) {
size_t free, total;
size_t free;
size_t total;
ggml_backend_dev_memory(model->devices[i], &free, &total);
// devices can return 0 bytes for free and total memory if they do not
// have any to report. in this case, we will use the host memory as a fallback
// fixes: https://github.com/ggml-org/llama.cpp/issues/18577
if (free == 0 && total == 0) {
ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (cpu_dev == nullptr) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
ggml_backend_dev_memory(cpu_dev, &free, &total);
}
ret[i].free = free;
ret[i].total = total;
}
@ -147,9 +159,8 @@ class llama_params_fit_exception : public std::runtime_error {
static void llama_params_fit_impl(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
constexpr int64_t MiB = 1024*1024;
const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
typedef std::vector<llama_device_memory_data> dmds_t;
const llama_model_params default_mparams = llama_model_default_params();
@ -168,6 +179,12 @@ static void llama_params_fit_impl(
return;
}
std::vector<int64_t> margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
margins.reserve(nd);
for (size_t id = 0; id < nd; id++) {
margins.push_back(margins_s[id]);
}
std::vector<std::string> dev_names;
{
dev_names.reserve(nd);
@ -187,9 +204,10 @@ static void llama_params_fit_impl(
int64_t sum_free = 0;
int64_t sum_projected_free = 0;
int64_t min_projected_free = INT64_MAX;
int64_t sum_projected_used = 0;
int64_t sum_projected_model = 0;
std::vector<int64_t> projected_free_per_device;
projected_free_per_device.reserve(nd);
if (nd > 1) {
LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@ -199,45 +217,63 @@ static void llama_params_fit_impl(
const int64_t projected_used = dmd.mb.total();
const int64_t projected_free = dmd.free - projected_used;
projected_free_per_device.push_back(projected_free);
sum_free += dmd.free;
sum_projected_used += projected_used;
sum_projected_free += projected_free;
min_projected_free = std::min(min_projected_free, projected_free);
sum_projected_model += dmd.mb.model;
if (nd > 1) {
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n",
__func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB,
projected_free >= 0 ? "surplus" : "deficit");
LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n",
__func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB);
}
}
assert(sum_free >= 0 && sum_projected_used >= 0);
LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n",
__func__, sum_projected_used/MiB, sum_free/MiB);
if (min_projected_free >= margin) {
if (nd == 1) {
if (nd == 1) {
if (projected_free_per_device[0] >= margins[0]) {
LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n",
__func__, min_projected_free/MiB, margin/MiB);
__func__, projected_free_per_device[0]/MiB, margins[0]/MiB);
return;
}
} else {
bool changes_needed = false;
for (size_t id = 0; id < nd; id++) {
if (projected_free_per_device[id] < margins[id]) {
changes_needed = true;
break;
}
}
if (!changes_needed) {
LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__);
return;
}
LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n",
__func__, min_projected_free/MiB, margin/MiB);
return;
}
// step 2: try reducing memory use by reducing the context size
{
int64_t global_surplus = sum_projected_free - int64_t(nd)*margin;
int64_t global_surplus = sum_projected_free;
for (size_t id = 0; id < nd; id++) {
global_surplus -= margins[id];
}
if (global_surplus < 0) {
LLAMA_LOG_INFO(nd == 1 ?
"%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" :
"%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n",
__func__, margin/MiB, -global_surplus/MiB);
if (nd == 1) {
LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n",
__func__, margins[0]/MiB, -global_surplus/MiB);
} else {
LLAMA_LOG_INFO(
"%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n",
__func__, -global_surplus/MiB);
}
if (cparams->n_ctx == 0) {
if (hp_nct > n_ctx_min) {
int64_t sum_used_target = sum_free - nd*margin_s;
int64_t sum_used_target = sum_free;
for (size_t id = 0; id < nd; id++) {
sum_used_target -= margins[id];
}
if (nd > 1) {
// for multiple devices we need to be more conservative in terms of how much context we think can fit:
// - for dense models only whole layers can be assigned to devices
@ -448,9 +484,9 @@ static void llama_params_fit_impl(
const dmds_t dmds_cpu_moe = llama_get_device_memory_data(
path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level);
for (const llama_device_memory_data & dmd : dmds_cpu_moe) {
global_surplus_cpu_moe += dmd.free;
global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin;
for (size_t id = 0; id < nd; id++) {
global_surplus_cpu_moe += dmds_cpu_moe[id].free;
global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id];
}
if (global_surplus_cpu_moe > 0) {
@ -469,7 +505,7 @@ static void llama_params_fit_impl(
std::vector<int64_t> targets; // maximum acceptable memory use per device
targets.reserve(nd);
for (size_t id = 0; id < nd; id++) {
targets.push_back(dmds_full[id].free - margin);
targets.push_back(dmds_full[id].free - margins[id]);
LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB);
}
@ -701,11 +737,11 @@ static void llama_params_fit_impl(
enum llama_params_fit_status llama_params_fit(
const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) {
const int64_t t0_us = llama_time_us();
llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS;
try {
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level);
llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level);
LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__);
} catch (const llama_params_fit_exception & e) {
LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what());
@ -794,7 +830,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
model.t_start_us = tm.t_start_us;
try {
llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
ml.print_info();

View File

@ -1,5 +1,6 @@
#include "arg.h"
#include "common.h"
#include "download.h"
#include <string>
#include <vector>

View File

@ -25,7 +25,6 @@ else()
if (LLAMA_BUILD_SERVER)
add_subdirectory(server)
endif()
add_subdirectory(run)
add_subdirectory(tokenize)
add_subdirectory(tts)
add_subdirectory(mtmd)

View File

@ -27,7 +27,7 @@ int main(int argc, char ** argv) {
auto mparams = common_model_params_to_llama(params);
auto cparams = common_context_params_to_llama(params);
const llama_params_fit_status status = llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) {
LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__);

View File

@ -1,23 +0,0 @@
set(TARGET llama-run)
add_executable(${TARGET} run.cpp linenoise.cpp/linenoise.cpp)
# TODO: avoid copying this code block from common/CMakeLists.txt
set(LLAMA_RUN_EXTRA_LIBS "")
if (LLAMA_CURL)
find_package(CURL REQUIRED)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
include_directories(${CURL_INCLUDE_DIRS})
set(LLAMA_RUN_EXTRA_LIBS ${LLAMA_RUN_EXTRA_LIBS} ${CURL_LIBRARIES})
endif ()
if(LLAMA_TOOLS_INSTALL)
install(TARGETS ${TARGET} RUNTIME)
endif()
if (CMAKE_SYSTEM_NAME MATCHES "AIX")
# AIX's flock() function comes from libbsd.a
target_link_libraries(${TARGET} PRIVATE -lbsd)
endif()
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} ${LLAMA_RUN_EXTRA_LIBS})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@ -1,52 +0,0 @@
# llama.cpp/example/run
The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models.
```bash
llama-run granite3-moe
```
```bash
Description:
Runs a llm
Usage:
llama-run [options] model [prompt]
Options:
-c, --context-size <value>
Context size (default: 2048)
-n, -ngl, --ngl <value>
Number of GPU layers (default: 0)
--temp <value>
Temperature (default: 0.8)
-v, --verbose, --log-verbose
Set verbosity level to infinity (i.e. log all messages, useful for debugging)
-h, --help
Show help message
Commands:
model
Model is a string with an optional prefix of
huggingface:// (hf://), ollama://, https:// or file://.
If no protocol is specified and a file exists in the specified
path, file:// is assumed, otherwise if a file does not exist in
the specified path, ollama:// is assumed. Models that are being
pulled are downloaded with .partial extension while being
downloaded and then renamed as the file without the .partial
extension when complete.
Examples:
llama-run llama3
llama-run ollama://granite-code
llama-run ollama://smollm:135m
llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
llama-run modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
llama-run https://example.com/some-file1.gguf
llama-run some-file2.gguf
llama-run file://some-file3.gguf
llama-run --ngl 999 some-file4.gguf
llama-run --ngl 999 some-file5.gguf Hello World
```

File diff suppressed because it is too large Load Diff

View File

@ -1,137 +0,0 @@
/* linenoise.h -- VERSION 1.0
*
* Guerrilla line editing library against the idea that a line editing lib
* needs to be 20,000 lines of C++ code.
*
* See linenoise.cpp for more information.
*
* ------------------------------------------------------------------------
*
* Copyright (c) 2010-2023, Salvatore Sanfilippo <antirez at gmail dot com>
* Copyright (c) 2010-2013, Pieter Noordhuis <pcnoordhuis at gmail dot com>
* Copyright (c) 2025, Eric Curtin <ericcurtin17 at gmail dot com>
*
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef __LINENOISE_H
#define __LINENOISE_H
#ifdef __cplusplus
extern "C" {
#endif
#include <stddef.h> /* For size_t. */
#include <stdlib.h>
extern const char * linenoiseEditMore;
/* The linenoiseState structure represents the state during line editing.
* We pass this state to functions implementing specific editing
* functionalities. */
struct linenoiseState {
int in_completion; /* The user pressed TAB and we are now in completion
* mode, so input is handled by completeLine(). */
size_t completion_idx; /* Index of next completion to propose. */
int ifd; /* Terminal stdin file descriptor. */
int ofd; /* Terminal stdout file descriptor. */
char * buf; /* Edited line buffer. */
size_t buflen; /* Edited line buffer size. */
const char * prompt; /* Prompt to display. */
size_t plen; /* Prompt length. */
size_t pos; /* Current cursor position. */
size_t oldcolpos; /* Previous refresh cursor column position. */
size_t len; /* Current edited line length. */
size_t cols; /* Number of columns in terminal. */
size_t oldrows; /* Rows used by last refreshed line (multiline mode) */
int history_index; /* The history index we are currently editing. */
};
struct linenoiseCompletions {
size_t len = 0;
char ** cvec = nullptr;
bool to_free = true;
~linenoiseCompletions() {
if (!to_free) {
return;
}
for (size_t i = 0; i < len; ++i) {
free(cvec[i]);
}
free(cvec);
}
};
/* Non blocking API. */
int linenoiseEditStart(struct linenoiseState * l, int stdin_fd, int stdout_fd, char * buf, size_t buflen,
const char * prompt);
const char * linenoiseEditFeed(struct linenoiseState * l);
void linenoiseEditStop(struct linenoiseState * l);
void linenoiseHide(struct linenoiseState * l);
void linenoiseShow(struct linenoiseState * l);
/* Blocking API. */
const char * linenoise(const char * prompt);
void linenoiseFree(void * ptr);
/* Completion API. */
typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *);
typedef const char *(linenoiseHintsCallback) (const char *, int * color, int * bold);
typedef void(linenoiseFreeHintsCallback)(const char *);
void linenoiseSetCompletionCallback(linenoiseCompletionCallback *);
void linenoiseSetHintsCallback(linenoiseHintsCallback *);
void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *);
void linenoiseAddCompletion(linenoiseCompletions *, const char *);
/* History API. */
int linenoiseHistoryAdd(const char * line);
int linenoiseHistorySetMaxLen(int len);
int linenoiseHistorySave(const char * filename);
int linenoiseHistoryLoad(const char * filename);
/* Other utilities. */
void linenoiseClearScreen(void);
void linenoiseSetMultiLine(int ml);
void linenoisePrintKeyCodes(void);
void linenoiseMaskModeEnable(void);
void linenoiseMaskModeDisable(void);
/* Encoding functions. */
typedef size_t(linenoisePrevCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len);
typedef size_t(linenoiseNextCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len);
typedef size_t(linenoiseReadCode)(int fd, char * buf, size_t buf_len, int * c);
void linenoiseSetEncodingFunctions(linenoisePrevCharLen * prevCharLenFunc, linenoiseNextCharLen * nextCharLenFunc,
linenoiseReadCode * readCodeFunc);
#ifdef __cplusplus
}
#endif
#endif /* __LINENOISE_H */

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,10 @@
#include "common.h"
#include "download.h"
#include "log.h"
#include "llama.h"
#include "mtmd.h"
#include "mtmd-helper.h"
#include "chat.h"
#include "arg.h" // for common_remote_get_content; TODO: use download.h only
#include "base64.hpp"
#include "server-common.h"
@ -779,7 +779,7 @@ static void handle_media(
// download remote image
// TODO @ngxson : maybe make these params configurable
common_remote_params params;
params.headers.push_back("User-Agent: llama.cpp/" + build_info);
params.headers.push_back({"User-Agent", "llama.cpp/" + build_info});
params.max_size = 1024 * 1024 * 10; // 10MB
params.timeout = 10; // seconds
SRV_INF("downloading image from '%s'\n", url.c_str());

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff