Merge remote-tracking branch 'ggml-org/master' into allozaur/mcp-mvp

This commit is contained in:
Aleksander Grygier 2026-01-29 13:21:44 +01:00
commit 9d6e210a5e
152 changed files with 12319 additions and 2796 deletions

View File

@ -19,7 +19,7 @@ on:
jobs:
check-vendor:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- name: Checkout

View File

@ -10,7 +10,7 @@ permissions:
jobs:
close-issues:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
permissions:
issues: write
pull-requests: write

View File

@ -20,7 +20,7 @@ concurrency:
jobs:
editorconfig:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- uses: actions/checkout@v6
- uses: editorconfig-checker/action-editorconfig-checker@v2

View File

@ -21,7 +21,7 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- uses: actions/checkout@v6

View File

@ -7,7 +7,7 @@ jobs:
permissions:
contents: read
pull-requests: write
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- uses: actions/checkout@v6
with:

View File

@ -12,7 +12,7 @@ on:
jobs:
pre-tokenizer-hashes:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- name: Checkout repository

View File

@ -20,7 +20,7 @@ concurrency:
jobs:
python-check-requirements:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
name: check-requirements
steps:
- name: Check out source repository

View File

@ -15,7 +15,7 @@ concurrency:
jobs:
flake8-lint:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
name: Lint
steps:
- name: Check out source repository

View File

@ -29,9 +29,7 @@ jobs:
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Install Python dependencies
# TODO: use a venv
run: pip install -r requirements/requirements-all.txt
pip-install: -r requirements/requirements-all.txt
- name: Type-check with Pyright
uses: jakebailey/pyright-action@v2
with:

View File

@ -14,7 +14,7 @@ on:
jobs:
update-ops-docs:
runs-on: ubuntu-latest
runs-on: ubuntu-slim
steps:
- name: Checkout repository

View File

@ -28,16 +28,17 @@ jobs:
owner: context.repo.owner,
repo: context.repo.repo,
});
console.log("Latest release:", releases[0].tag_name);
return releases[0].tag_name;
const { tag_name: version, assets: assets } = releases.find(({assets}) => assets.find(asset => asset.name.includes('win-vulkan')));
const { browser_download_url: asset_url } = assets.find(asset => asset.name.includes('win-vulkan'));
console.log("Latest release:", version);
core.setOutput('VERSION', version);
core.setOutput('ASSETURL', asset_url);
- name: Update manifest
env:
VERSION: ${{ steps.find_latest_release.outputs.result }}
run: |
echo "Updating manifest..."
komac update --version ${{ env.VERSION }} \
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
komac update --version ${{ steps.find_latest_release.outputs.VERSION }} \
--urls "${{ steps.find_latest_release.outputs.ASSETURL }}" \
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
--submit \
ggml.llamacpp

View File

@ -18,6 +18,7 @@
/common/jinja/ @ngxson @CISC @aldehir
/common/llguidance.* @ggerganov
/common/log.* @ggerganov
/common/ngram-map.* @srogmann
/common/peg-parser.* @aldehir
/common/sampling.* @ggerganov
/common/speculative.* @ggerganov
@ -67,6 +68,7 @@
/ggml/src/ggml-rpc/ @rgerganov
/ggml/src/ggml-threading.* @ggerganov
/ggml/src/ggml-vulkan/ @0cc4m
/ggml/src/ggml-virtgpu/ @kpouget
/ggml/src/ggml-webgpu/ @reeselevine
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
/ggml/src/ggml.c @ggerganov

View File

@ -132,6 +132,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
- [x] [RWKV-7](https://huggingface.co/collections/shoumenchougou/rwkv7-gxx-gguf)
- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)

View File

@ -73,6 +73,8 @@ add_library(${TARGET} STATIC
log.h
ngram-cache.cpp
ngram-cache.h
ngram-map.cpp
ngram-map.h
peg-parser.cpp
peg-parser.h
preset.cpp

View File

@ -6,6 +6,7 @@
#include "json-schema-to-grammar.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
#include "preset.h"
// fix problem with std::min and std::max
@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.mmproj = res.mmproj;
}
// only download mmproj if the current example is using it
for (auto & ex : mmproj_examples) {
for (const auto & ex : mmproj_examples) {
if (ctx_arg.ex == ex) {
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
break;
}
}
common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
}
// model is required (except for server)
@ -1216,21 +1217,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-lcs", "--lookup-cache-static"}, "FNAME",
"path to static lookup cache to use for lookup decoding (not updated by generation)",
[](common_params & params, const std::string & value) {
params.lookup_cache_static = value;
params.speculative.lookup_cache_static = value;
}
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-lcd", "--lookup-cache-dynamic"}, "FNAME",
"path to dynamic lookup cache to use for lookup decoding (updated by generation)",
[](common_params & params, const std::string & value) {
params.lookup_cache_dynamic = value;
params.speculative.lookup_cache_dynamic = value;
}
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-c", "--ctx-size"}, "N",
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
[](common_params & params, int value) {
params.n_ctx = value;
if (value == 0) {
// disable context reduction in llama_params_fit if the user explicitly requests the full context size:
params.fit_params_min_ctx = UINT32_MAX;
}
}
).set_env("LLAMA_ARG_CTX_SIZE"));
add_opt(common_arg(
@ -1291,11 +1296,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"-kvu", "--kv-unified"},
{"-no-kvu", "--no-kv-unified"},
"use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)",
[](common_params & params) {
params.kv_unified = true;
[](common_params & params, bool value) {
params.kv_unified = value;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED}));
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
@ -1573,7 +1579,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--temp"}, "N",
string_format("temperature (default: %.1f)", (double)params.sampling.temp),
string_format("temperature (default: %.2f)", (double)params.sampling.temp),
[](common_params & params, const std::string & value) {
params.sampling.temp = std::stof(value);
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
@ -1590,7 +1596,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam().set_env("LLAMA_ARG_TOP_K"));
add_opt(common_arg(
{"--top-p"}, "N",
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
string_format("top-p sampling (default: %.2f, 1.0 = disabled)", (double)params.sampling.top_p),
[](common_params & params, const std::string & value) {
params.sampling.top_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
@ -1598,7 +1604,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--min-p"}, "N",
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
string_format("min-p sampling (default: %.2f, 0.0 = disabled)", (double)params.sampling.min_p),
[](common_params & params, const std::string & value) {
params.sampling.min_p = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
@ -1606,14 +1612,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--top-nsigma"}, "N",
string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma),
[](common_params & params, const std::string & value) {
params.sampling.top_n_sigma = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--xtc-probability"}, "N",
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
string_format("xtc probability (default: %.2f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
[](common_params & params, const std::string & value) {
params.sampling.xtc_probability = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
@ -1621,7 +1627,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--xtc-threshold"}, "N",
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
string_format("xtc threshold (default: %.2f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
[](common_params & params, const std::string & value) {
params.sampling.xtc_threshold = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
@ -1629,7 +1635,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--typical"}, "N",
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p),
string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p),
[](common_params & params, const std::string & value) {
params.sampling.typ_p = std::stof(value);
}
@ -1648,7 +1654,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--repeat-penalty"}, "N",
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
string_format("penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
[](common_params & params, const std::string & value) {
params.sampling.penalty_repeat = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
@ -1656,21 +1662,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--presence-penalty"}, "N",
string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present),
string_format("repeat alpha presence penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_present),
[](common_params & params, const std::string & value) {
params.sampling.penalty_present = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--frequency-penalty"}, "N",
string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
string_format("repeat alpha frequency penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
[](common_params & params, const std::string & value) {
params.sampling.penalty_freq = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dry-multiplier"}, "N",
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
string_format("set DRY sampling multiplier (default: %.2f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
[](common_params & params, const std::string & value) {
params.sampling.dry_multiplier = std::stof(value);
}
@ -1751,14 +1757,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--dynatemp-range"}, "N",
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
string_format("dynamic temperature range (default: %.2f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
[](common_params & params, const std::string & value) {
params.sampling.dynatemp_range = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--dynatemp-exp"}, "N",
string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent),
string_format("dynamic temperature exponent (default: %.2f)", (double)params.sampling.dynatemp_exponent),
[](common_params & params, const std::string & value) {
params.sampling.dynatemp_exponent = std::stof(value);
}
@ -1774,7 +1780,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--mirostat-lr"}, "N",
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
string_format("Mirostat learning rate, parameter eta (default: %.2f)", (double)params.sampling.mirostat_eta),
[](common_params & params, const std::string & value) {
params.sampling.mirostat_eta = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
@ -1782,7 +1788,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_sparam());
add_opt(common_arg(
{"--mirostat-ent"}, "N",
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
string_format("Mirostat target entropy, parameter tau (default: %.2f)", (double)params.sampling.mirostat_tau),
[](common_params & params, const std::string & value) {
params.sampling.mirostat_tau = std::stof(value);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
@ -1916,28 +1922,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("LLAMA_ARG_YARN_ORIG_CTX"));
add_opt(common_arg(
{"--yarn-ext-factor"}, "N",
string_format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
string_format("YaRN: extrapolation mix factor (default: %.2f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
[](common_params & params, const std::string & value) {
params.yarn_ext_factor = std::stof(value);
}
).set_env("LLAMA_ARG_YARN_EXT_FACTOR"));
add_opt(common_arg(
{"--yarn-attn-factor"}, "N",
string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor),
string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.2f)", (double)params.yarn_attn_factor),
[](common_params & params, const std::string & value) {
params.yarn_attn_factor = std::stof(value);
}
).set_env("LLAMA_ARG_YARN_ATTN_FACTOR"));
add_opt(common_arg(
{"--yarn-beta-slow"}, "N",
string_format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow),
string_format("YaRN: high correction dim or alpha (default: %.2f)", (double)params.yarn_beta_slow),
[](common_params & params, const std::string & value) {
params.yarn_beta_slow = std::stof(value);
}
).set_env("LLAMA_ARG_YARN_BETA_SLOW"));
add_opt(common_arg(
{"--yarn-beta-fast"}, "N",
string_format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast),
string_format("YaRN: low correction dim or beta (default: %.2f)", (double)params.yarn_beta_fast),
[](common_params & params, const std::string & value) {
params.yarn_beta_fast = std::stof(value);
}
@ -2194,18 +2200,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg(
{"--mmap"},
{"--no-mmap"},
string_format("whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"),
string_format("whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.use_mmap = value;
if (value) {
params.use_direct_io = false; // disable direct io when mmap is explicitly enabled
}
}
).set_env("LLAMA_ARG_MMAP"));
add_opt(common_arg(
{"-dio", "--direct-io"},
{"-ndio", "--no-direct-io"},
string_format("use DirectIO if available. Takes precedence over --mmap (default: %s)", params.use_direct_io ? "enabled" : "disabled"),
string_format("use DirectIO if available. (default: %s)", params.use_direct_io ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.use_direct_io = value;
}
@ -2561,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
"Same as --hf-repo, but for the draft model (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.model.hf_repo = value;
params.speculative.mparams_dft.hf_repo = value;
}
).set_env("LLAMA_ARG_HFD_REPO"));
add_opt(common_arg(
@ -3331,14 +3334,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MIN"));
add_opt(common_arg(
{"--draft-p-split"}, "P",
string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split),
string_format("speculative decoding split probability (default: %.2f)", (double)params.speculative.p_split),
[](common_params & params, const std::string & value) {
params.speculative.p_split = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT"));
add_opt(common_arg(
{"--draft-p-min"}, "P",
string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min),
[](common_params & params, const std::string & value) {
params.speculative.p_min = std::stof(value);
}
@ -3382,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-md", "--model-draft"}, "FNAME",
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.model.path = value;
params.speculative.mparams_dft.path = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
add_opt(common_arg(
@ -3392,6 +3395,66 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.replacements.push_back({ tgt, dft });
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]",
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
common_speculative_type_to_str(params.speculative.type).c_str()),
[](common_params & params, const std::string & value) {
if (value == "none") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
} else if (value == "ngram-cache") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
} else if (value == "ngram-simple") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
} else if (value == "ngram-map-k") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
} else if (value == "ngram-map-k4v") {
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
} else {
throw std::invalid_argument("unknown speculative decoding type without draft model");
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-size-n"}, "N",
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
[](common_params & params, int value) {
if (value < 1 || value > 1024) {
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
}
params.speculative.ngram_size_n = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-size-m"}, "N",
string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
[](common_params & params, int value) {
if (value < 1 || value > 1024) {
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
}
params.speculative.ngram_size_m = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-check-rate"}, "N",
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
[](common_params & params, int value) {
if (value < 1) {
throw std::invalid_argument("ngram check rate must be at least 1");
}
params.speculative.ngram_check_rate = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-min-hits"}, "N",
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
[](common_params & params, int value) {
if (value < 1) {
throw std::invalid_argument("ngram min hits must be at least 1");
}
params.speculative.ngram_min_hits = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
@ -3618,8 +3681,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.port = 8012;
params.n_ubatch = 1024;
params.n_batch = 1024;
@ -3634,8 +3697,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.port = 8012;
params.n_ubatch = 1024;
params.n_batch = 1024;

View File

@ -2659,6 +2659,10 @@ static common_chat_params common_chat_params_init_translate_gemma(const common_c
templates_params inputs_new = inputs;
json & messages = inputs_new.messages;
// default to chat_template_kwargs, or en-GB if not specified
std::string default_src_lang = inputs.extra_context.value("source_lang_code", "en-GB");
std::string default_tgt_lang = inputs.extra_context.value("target_lang_code", "en-GB");
GGML_ASSERT(messages.is_array());
for (auto & message : messages) {
if (message.contains("role") && message["role"].get<std::string>() != "user") {
@ -2670,8 +2674,10 @@ static common_chat_params common_chat_params_init_translate_gemma(const common_c
if (message.contains("content") && !message["content"].is_array()) {
auto content_str = message["content"].get<std::string>();
// default to en-GB if not specified (to make common_chat_format_example works)
auto src_lang = message.contains("source_lang_code") ? message["source_lang_code"].get<std::string>() : "en-GB";
auto tgt_lang = message.contains("target_lang_code") ? message["target_lang_code"].get<std::string>() : "en-GB";
auto src_lang = message.contains("source_lang_code")
? message["source_lang_code"].get<std::string>() : default_src_lang;
auto tgt_lang = message.contains("target_lang_code")
? message["target_lang_code"].get<std::string>() : default_tgt_lang;
message["content"] = json::array({
json{
{"type", "text"},

View File

@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
params.tensor_split,
params.tensor_buft_overrides.data(),
params.fit_params_target.data(),
params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
}
@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
void common_init_result::free_context() {
pimpl->context.reset();
}
common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));

View File

@ -164,6 +164,16 @@ enum common_params_sampling_config : uint64_t {
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
};
enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
};
// sampling parameters
struct common_params_sampling {
@ -243,16 +253,35 @@ struct common_params_model {
};
struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
// general-purpose speculative decoding parameters
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
// ngram-based speculative decoding
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
// draft-model speculative decoding
struct common_params_model mparams_dft;
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
int32_t n_ctx = 0; // draft context size
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
@ -260,7 +289,14 @@ struct common_params_speculative {
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
struct common_params_model model;
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
bool has_dft() const {
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
}
};
struct common_params_vocoder {
@ -378,8 +414,6 @@ struct common_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
// llama-debug specific options
@ -438,7 +472,7 @@ struct common_params {
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool use_mmap = true; // enable mmap to use filesystem cache
bool use_direct_io = true; // read from disk without buffering for faster model loading
bool use_direct_io = false; // read from disk without buffering
bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation
bool display_prompt = true; // print prompt before generation
@ -575,10 +609,6 @@ struct common_params {
// return false from callback to abort model loading or true to continue
llama_progress_callback load_progress_callback = NULL;
void * load_progress_callback_user_data = NULL;
bool has_speculative() const {
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
}
};
// call once at the start of a program if it uses libcommon
@ -714,8 +744,6 @@ struct common_init_result {
std::vector<llama_adapter_lora_ptr> & lora();
void free_context();
private:
struct impl;
std::unique_ptr<impl> pimpl;

View File

@ -60,10 +60,10 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
#ifndef CPPHTTPLIB_OPENSSL_SUPPORT
if (parts.scheme == "https") {
throw std::runtime_error(
"HTTPS is not supported. Please rebuild with:\n"
"HTTPS is not supported. Please rebuild with one of:\n"
" -DLLAMA_BUILD_BORINGSSL=ON\n"
" -DLLAMA_BUILD_LIBRESSL=ON\n"
"or ensure dev files of an OpenSSL-compatible library are available when building."
" -DLLAMA_OPENSSL=ON (default, requires OpenSSL dev files installed)"
);
}
#endif

View File

@ -44,6 +44,12 @@ static std::string get_line_col(const std::string & source, size_t pos) {
return "line " + std::to_string(line) + ", column " + std::to_string(col);
}
static void ensure_key_type_allowed(const value & val) {
if (!val->is_hashable()) {
throw std::runtime_error("Type: " + val->type() + " is not allowed as object key");
}
}
// execute with error handling
value statement::execute(context & ctx) {
try {
@ -95,20 +101,10 @@ value identifier::execute_impl(context & ctx) {
value object_literal::execute_impl(context & ctx) {
auto obj = mk_val<value_object>();
for (const auto & pair : val) {
value key_val = pair.first->execute(ctx);
if (!is_val<value_string>(key_val) && !is_val<value_int>(key_val)) {
throw std::runtime_error("Object literal: keys must be string or int values, got " + key_val->type());
}
std::string key = key_val->as_string().str();
value key = pair.first->execute(ctx);
value val = pair.second->execute(ctx);
JJ_DEBUG("Object literal: setting key '%s' with value type %s", key.c_str(), val->type().c_str());
JJ_DEBUG("Object literal: setting key '%s' with value type %s", key->as_string().str().c_str(), val->type().c_str());
obj->insert(key, val);
if (is_val<value_int>(key_val)) {
obj->val_obj.is_key_numeric = true;
} else if (obj->val_obj.is_key_numeric) {
throw std::runtime_error("Object literal: cannot mix numeric and non-numeric keys");
}
}
return obj;
}
@ -127,9 +123,9 @@ value binary_expression::execute_impl(context & ctx) {
value right_val = right->execute(ctx);
JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str());
if (op.value == "==") {
return mk_val<value_bool>(value_compare(left_val, right_val, value_compare_op::eq));
return mk_val<value_bool>(*left_val == *right_val);
} else if (op.value == "!=") {
return mk_val<value_bool>(!value_compare(left_val, right_val, value_compare_op::eq));
return mk_val<value_bool>(!(*left_val == *right_val));
}
auto workaround_concat_null_with_str = [&](value & res) -> bool {
@ -230,7 +226,7 @@ value binary_expression::execute_impl(context & ctx) {
auto & arr = right_val->as_array();
bool member = false;
for (const auto & item : arr) {
if (value_compare(left_val, item, value_compare_op::eq)) {
if (*left_val == *item) {
member = true;
break;
}
@ -265,10 +261,9 @@ value binary_expression::execute_impl(context & ctx) {
}
}
// String in object
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
auto key = left_val->as_string().str();
bool has_key = right_val->has_key(key);
// Value key in object
if (is_val<value_object>(right_val)) {
bool has_key = right_val->has_key(left_val);
if (op.value == "in") {
return mk_val<value_bool>(has_key);
} else if (op.value == "not in") {
@ -465,14 +460,8 @@ value for_statement::execute_impl(context & ctx) {
JJ_DEBUG("%s", "For loop over object keys");
auto & obj = iterable_val->as_ordered_object();
for (auto & p : obj) {
auto tuple = mk_val<value_array>();
if (iterable_val->val_obj.is_key_numeric) {
tuple->push_back(mk_val<value_int>(std::stoll(p.first)));
} else {
tuple->push_back(mk_val<value_string>(p.first));
}
tuple->push_back(p.second);
items.push_back(tuple);
auto tuple = mk_val<value_tuple>(p);
items.push_back(std::move(tuple));
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
@ -602,11 +591,13 @@ value set_statement::execute_impl(context & ctx) {
auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx);
if (is_stmt<identifier>(assignee)) {
// case: {% set my_var = value %}
auto var_name = cast_stmt<identifier>(assignee)->val;
JJ_DEBUG("Setting global variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
ctx.set_val(var_name, rhs);
} else if (is_stmt<tuple_literal>(assignee)) {
// case: {% set a, b = value %}
auto tuple = cast_stmt<tuple_literal>(assignee);
if (!is_val<value_array>(rhs)) {
throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type());
@ -625,6 +616,7 @@ value set_statement::execute_impl(context & ctx) {
}
} else if (is_stmt<member_expression>(assignee)) {
// case: {% set ns.my_var = value %}
auto member = cast_stmt<member_expression>(assignee);
if (member->computed) {
throw std::runtime_error("Cannot assign to computed member");
@ -767,22 +759,22 @@ value member_expression::execute_impl(context & ctx) {
}
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
ensure_key_type_allowed(property);
value val = mk_val<value_undefined>("object_property");
if (is_val<value_undefined>(object)) {
JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
return val;
} else if (is_val<value_object>(object)) {
if (!is_val<value_string>(property)) {
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
}
auto key = property->as_string().str();
val = object->at(key, val);
val = object->at(property, val);
if (is_val<value_undefined>(val)) {
val = try_builtin_func(ctx, key, object, true);
}
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
} else if (is_val<value_array>(object) || is_val<value_string>(object)) {
if (is_val<value_int>(property)) {
int64_t index = property->as_int();
@ -806,6 +798,7 @@ value member_expression::execute_impl(context & ctx) {
auto key = property->as_string().str();
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
val = try_builtin_func(ctx, key, object, true);
} else {
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
}

View File

@ -79,18 +79,18 @@ struct context {
}
value get_val(const std::string & name) {
auto it = env->val_obj.unordered.find(name);
if (it != env->val_obj.unordered.end()) {
return it->second;
} else {
return mk_val<value_undefined>(name);
}
value default_val = mk_val<value_undefined>(name);
return env->at(name, default_val);
}
void set_val(const std::string & name, const value & val) {
env->insert(name, val);
}
void set_val(const value & name, const value & val) {
env->insert(name, val);
}
void print_vars() const {
printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str());
}
@ -344,9 +344,19 @@ struct array_literal : public expression {
}
};
struct tuple_literal : public array_literal {
explicit tuple_literal(statements && val) : array_literal(std::move(val)) {}
struct tuple_literal : public expression {
statements val;
explicit tuple_literal(statements && val) : val(std::move(val)) {
for (const auto& item : this->val) chk_type<expression>(item);
}
std::string type() const override { return "TupleLiteral"; }
value execute_impl(context & ctx) override {
auto arr = mk_val<value_array>();
for (const auto & item_stmt : val) {
arr->push_back(item_stmt->execute(ctx));
}
return mk_val<value_tuple>(std::move(arr->as_array()));
}
};
struct object_literal : public expression {

View File

@ -61,6 +61,12 @@ size_t string::length() const {
return len;
}
void string::hash_update(hasher & hash) const noexcept {
for (const auto & part : parts) {
hash.update(part.val.data(), part.val.length());
}
}
bool string::all_parts_are_input() const {
for (const auto & part : parts) {
if (!part.is_input) {

View File

@ -4,6 +4,8 @@
#include <string>
#include <vector>
#include "utils.h"
namespace jinja {
// allow differentiate between user input strings and template strings
@ -37,6 +39,7 @@ struct string {
std::string str() const;
size_t length() const;
void hash_update(hasher & hash) const noexcept;
bool all_parts_are_input() const;
bool is_uppercase() const;
bool is_lowercase() const;

View File

@ -3,6 +3,8 @@
#include <string>
#include <sstream>
#include <algorithm>
#include <cstdint>
#include <cstring>
namespace jinja {
@ -46,4 +48,102 @@ static std::string fmt_error_with_source(const std::string & tag, const std::str
return oss.str();
}
// Note: this is a simple hasher, not cryptographically secure, just for hash table usage
struct hasher {
static constexpr auto size_t_digits = sizeof(size_t) * 8;
static constexpr size_t prime = size_t_digits == 64 ? 0x100000001b3 : 0x01000193;
static constexpr size_t seed = size_t_digits == 64 ? 0xcbf29ce484222325 : 0x811c9dc5;
static constexpr auto block_size = sizeof(size_t); // in bytes; allowing the compiler to vectorize the computation
static_assert(size_t_digits == 64 || size_t_digits == 32);
static_assert(block_size == 8 || block_size == 4);
uint8_t buffer[block_size];
size_t idx = 0; // current index in buffer
size_t state = seed;
hasher() = default;
hasher(const std::type_info & type_inf) noexcept {
const auto type_hash = type_inf.hash_code();
update(&type_hash, sizeof(type_hash));
}
// Properties:
// - update is not associative: update(a).update(b) != update(b).update(a)
// - update(a ~ b) == update(a).update(b) with ~ as concatenation operator --> useful for streaming
// - update("", 0) --> state unchanged with empty input
hasher& update(void const * bytes, size_t len) noexcept {
const uint8_t * c = static_cast<uint8_t const *>(bytes);
if (len == 0) {
return *this;
}
size_t processed = 0;
// first, fill the existing buffer if it's partial
if (idx > 0) {
size_t to_fill = block_size - idx;
if (to_fill > len) {
to_fill = len;
}
std::memcpy(buffer + idx, c, to_fill);
idx += to_fill;
processed += to_fill;
if (idx == block_size) {
update_block(buffer);
idx = 0;
}
}
// process full blocks from the remaining input
for (; processed + block_size <= len; processed += block_size) {
update_block(c + processed);
}
// buffer any remaining bytes
size_t remaining = len - processed;
if (remaining > 0) {
std::memcpy(buffer, c + processed, remaining);
idx = remaining;
}
return *this;
}
// convenience function for testing only
hasher& update(const std::string & s) noexcept {
return update(s.data(), s.size());
}
// finalize and get the hash value
// note: after calling digest, the hasher state is modified, do not call update() again
size_t digest() noexcept {
// if there are remaining bytes in buffer, fill the rest with zeros and process
if (idx > 0) {
for (size_t i = idx; i < block_size; ++i) {
buffer[i] = 0;
}
update_block(buffer);
idx = 0;
}
return state;
}
private:
// IMPORTANT: block must have at least block_size bytes
void update_block(const uint8_t * block) noexcept {
size_t blk = static_cast<uint32_t>(block[0])
| (static_cast<uint32_t>(block[1]) << 8)
| (static_cast<uint32_t>(block[2]) << 16)
| (static_cast<uint32_t>(block[3]) << 24);
if constexpr (block_size == 8) {
blk = blk | (static_cast<uint64_t>(block[4]) << 32)
| (static_cast<uint64_t>(block[5]) << 40)
| (static_cast<uint64_t>(block[6]) << 48)
| (static_cast<uint64_t>(block[7]) << 56);
}
state ^= blk;
state *= prime;
}
};
} // namespace jinja

View File

@ -114,6 +114,18 @@ static T slice(const T & array, int64_t start, int64_t stop, int64_t step = 1) {
return result;
}
template<typename T>
static value empty_value_fn(const func_args &) {
if constexpr (std::is_same_v<T, value_int>) {
return mk_val<T>(0);
} else if constexpr (std::is_same_v<T, value_float>) {
return mk_val<T>(0.0);
} else if constexpr (std::is_same_v<T, value_bool>) {
return mk_val<T>(false);
} else {
return mk_val<T>();
}
}
template<typename T>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
@ -128,6 +140,13 @@ static value test_type_fn(const func_args & args) {
JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0);
return mk_val<value_bool>(is_type);
}
template<typename T, typename U, typename V>
static value test_type_fn(const func_args & args) {
args.ensure_count(1);
bool is_type = is_val<T>(args.get_pos(0)) || is_val<U>(args.get_pos(0)) || is_val<V>(args.get_pos(0));
JJ_DEBUG("test_type_fn: type=%s, %s or %s result=%d", typeid(T).name(), typeid(U).name(), typeid(V).name(), is_type ? 1 : 0);
return mk_val<value_bool>(is_type);
}
template<value_compare_op op>
static value test_compare_fn(const func_args & args) {
args.ensure_count(2, 2);
@ -163,7 +182,7 @@ static value selectattr(const func_args & args) {
args.ensure_vals<value_array, value_string, value_string, value_string>(true, true, false, false);
auto arr = args.get_pos(0)->as_array();
auto attr_name = args.get_pos(1)->as_string().str();
auto attribute = args.get_pos(1);
auto out = mk_val<value_array>();
value val_default = mk_val<value_undefined>();
@ -173,7 +192,7 @@ static value selectattr(const func_args & args) {
if (!is_val<value_object>(item)) {
throw raised_exception("selectattr: item is not an object");
}
value attr_val = item->at(attr_name, val_default);
value attr_val = item->at(attribute, val_default);
bool is_selected = attr_val->as_bool();
if constexpr (is_reject) is_selected = !is_selected;
if (is_selected) out->push_back(item);
@ -217,7 +236,7 @@ static value selectattr(const func_args & args) {
if (!is_val<value_object>(item)) {
throw raised_exception("selectattr: item is not an object");
}
value attr_val = item->at(attr_name, val_default);
value attr_val = item->at(attribute, val_default);
func_args test_args(args.ctx);
test_args.push_back(attr_val); // attribute value
test_args.push_back(extra_arg); // extra argument
@ -347,8 +366,8 @@ const func_builtins & global_builtins() {
{"test_is_integer", test_type_fn<value_int>},
{"test_is_float", test_type_fn<value_float>},
{"test_is_number", test_type_fn<value_int, value_float>},
{"test_is_iterable", test_type_fn<value_array, value_string>},
{"test_is_sequence", test_type_fn<value_array, value_string>},
{"test_is_iterable", test_type_fn<value_array, value_string, value_undefined>},
{"test_is_sequence", test_type_fn<value_array, value_string, value_undefined>},
{"test_is_mapping", test_type_fn<value_object>},
{"test_is_lower", [](const func_args & args) -> value {
args.ensure_vals<value_string>();
@ -741,6 +760,7 @@ const func_builtins & value_array_t::get_builtins() const {
args.ensure_count(1, 4);
args.ensure_vals<value_array, value_int, value_int, value_int>(true, true, false, false);
auto val = args.get_pos(0);
auto arg0 = args.get_pos(1);
auto arg1 = args.get_pos(2, mk_val<value_undefined>());
auto arg2 = args.get_pos(3, mk_val<value_undefined>());
@ -762,10 +782,8 @@ const func_builtins & value_array_t::get_builtins() const {
if (step == 0) {
throw raised_exception("slice step cannot be zero");
}
auto arr = slice(args.get_pos(0)->as_array(), start, stop, step);
auto res = mk_val<value_array>();
res->val_arr = std::move(arr);
return res;
auto arr = slice(val->as_array(), start, stop, step);
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"selectattr", selectattr<false>},
{"select", selectattr<false>},
@ -785,15 +803,14 @@ const func_builtins & value_array_t::get_builtins() const {
}
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::string result;
for (size_t i = 0; i < arr.size(); ++i) {
value val_arr = arr[i];
if (!attribute->is_undefined()) {
if (attr_is_int && is_val<value_array>(val_arr)) {
val_arr = val_arr->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
val_arr = val_arr->at(attr_name);
} else if (!attr_is_int && is_val<value_object>(val_arr)) {
val_arr = val_arr->at(attribute);
}
}
if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
@ -808,9 +825,7 @@ const func_builtins & value_array_t::get_builtins() const {
}},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
auto str = mk_val<value_string>();
gather_string_parts_recursive(args.get_pos(0), str);
return str;
return mk_val<value_string>(args.get_pos(0)->as_string());
}},
{"tojson", tojson},
{"map", [](const func_args & args) -> value {
@ -821,26 +836,26 @@ const func_builtins & value_array_t::get_builtins() const {
if (!is_val<value_kwarg>(args.get_args().at(1))) {
throw not_implemented_exception("map: filter-mapping not implemented");
}
value val = args.get_pos(0);
value attribute = args.get_kwarg_or_pos("attribute", 1);
const bool attr_is_int = is_val<value_int>(attribute);
if (!is_val<value_string>(attribute) && !attr_is_int) {
throw raised_exception("map: attribute must be string or integer");
}
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->as_string().str();
value default_val = args.get_kwarg("default", mk_val<value_undefined>());
auto out = mk_val<value_array>();
auto arr = args.get_pos(0)->as_array();
auto arr = val->as_array();
for (const auto & item : arr) {
value attr_val;
if (attr_is_int) {
attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
} else {
attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
attr_val = is_val<value_object>(item) ? item->at(attribute, default_val) : default_val;
}
out->push_back(attr_val);
}
return out;
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(out->as_array())) : out;
}},
{"append", [](const func_args & args) -> value {
args.ensure_count(2);
@ -867,6 +882,7 @@ const func_builtins & value_array_t::get_builtins() const {
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("sort: first argument must be an array");
}
value val = args.get_pos(0);
value val_reverse = args.get_kwarg_or_pos("reverse", 1);
value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
value attribute = args.get_kwarg_or_pos("attribute", 3);
@ -875,8 +891,7 @@ const func_builtins & value_array_t::get_builtins() const {
const bool reverse = val_reverse->as_bool(); // undefined == false
const bool attr_is_int = is_val<value_int>(attribute);
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
std::vector<value> arr = val->as_array(); // copy
std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
value val_a = a;
value val_b = b;
@ -884,22 +899,23 @@ const func_builtins & value_array_t::get_builtins() const {
if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
val_a = a->at(attr_int);
val_b = b->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
val_a = a->at(attr_name);
val_b = b->at(attr_name);
} else if (!attr_is_int && is_val<value_object>(a) && is_val<value_object>(b)) {
val_a = a->at(attribute);
val_b = b->at(attribute);
} else {
throw raised_exception("sort: unsupported object attribute comparison");
throw raised_exception("sort: unsupported object attribute comparison between " + a->type() + " and " + b->type());
}
}
return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
});
return mk_val<value_array>(arr);
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"reverse", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
value val = args.get_pos(0);
std::vector<value> arr = val->as_array(); // copy
std::reverse(arr.begin(), arr.end());
return mk_val<value_array>(arr);
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"unique", [](const func_args &) -> value {
throw not_implemented_exception("Array unique builtin not implemented");
@ -930,7 +946,7 @@ const func_builtins & value_object_t::get_builtins() const {
default_val = args.get_pos(2);
}
const value obj = args.get_pos(0);
std::string key = args.get_pos(1)->as_string().str();
const value key = args.get_pos(1);
return obj->at(key, default_val);
}},
{"keys", [](const func_args & args) -> value {
@ -938,7 +954,7 @@ const func_builtins & value_object_t::get_builtins() const {
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
result->push_back(mk_val<value_string>(pair.first));
result->push_back(pair.first);
}
return result;
}},
@ -956,15 +972,16 @@ const func_builtins & value_object_t::get_builtins() const {
const auto & obj = args.get_pos(0)->as_ordered_object();
auto result = mk_val<value_array>();
for (const auto & pair : obj) {
auto item = mk_val<value_array>();
item->push_back(mk_val<value_string>(pair.first));
item->push_back(pair.second);
auto item = mk_val<value_tuple>(pair);
result->push_back(std::move(item));
}
return result;
}},
{"tojson", tojson},
{"string", tojson},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
return mk_val<value_string>(args.get_pos(0)->as_string());
}},
{"length", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
const auto & obj = args.get_pos(0)->as_ordered_object();
@ -985,11 +1002,11 @@ const func_builtins & value_object_t::get_builtins() const {
const bool reverse = val_reverse->as_bool(); // undefined == false
const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
auto result = mk_val<value_object>(val_input); // copy
std::sort(result->val_obj.ordered.begin(), result->val_obj.ordered.end(), [&](const auto & a, const auto & b) {
std::sort(result->val_obj.begin(), result->val_obj.end(), [&](const auto & a, const auto & b) {
if (by_value) {
return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
} else {
return reverse ? a.first > b.first : a.first < b.first;
return value_compare(a.first, b.first, reverse ? value_compare_op::gt : value_compare_op::lt);
}
});
return result;
@ -1005,7 +1022,12 @@ const func_builtins & value_none_t::get_builtins() const {
static const func_builtins builtins = {
{"default", default_value},
{"tojson", tojson},
{"string", [](const func_args &) -> value { return mk_val<value_string>("None"); }}
{"string", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"safe", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
};
return builtins;
}
@ -1014,10 +1036,33 @@ const func_builtins & value_none_t::get_builtins() const {
const func_builtins & value_undefined_t::get_builtins() const {
static const func_builtins builtins = {
{"default", default_value},
{"tojson", [](const func_args & args) -> value {
args.ensure_vals<value_undefined>();
return mk_val<value_string>("null");
}},
{"capitalize", empty_value_fn<value_string>},
{"first", empty_value_fn<value_undefined>},
{"items", empty_value_fn<value_array>},
{"join", empty_value_fn<value_string>},
{"last", empty_value_fn<value_undefined>},
{"length", empty_value_fn<value_int>},
{"list", empty_value_fn<value_array>},
{"lower", empty_value_fn<value_string>},
{"map", empty_value_fn<value_array>},
{"max", empty_value_fn<value_undefined>},
{"min", empty_value_fn<value_undefined>},
{"reject", empty_value_fn<value_array>},
{"rejectattr", empty_value_fn<value_array>},
{"replace", empty_value_fn<value_string>},
{"reverse", empty_value_fn<value_array>},
{"safe", empty_value_fn<value_string>},
{"select", empty_value_fn<value_array>},
{"selectattr", empty_value_fn<value_array>},
{"sort", empty_value_fn<value_array>},
{"string", empty_value_fn<value_string>},
{"strip", empty_value_fn<value_string>},
{"sum", empty_value_fn<value_int>},
{"title", empty_value_fn<value_string>},
{"truncate", empty_value_fn<value_string>},
{"unique", empty_value_fn<value_array>},
{"upper", empty_value_fn<value_string>},
{"wordcount", empty_value_fn<value_int>},
};
return builtins;
}
@ -1134,6 +1179,8 @@ void global_from_json(context & ctx, const nlohmann::ordered_json & json_obj, bo
}
}
// recursively convert value to JSON string
// TODO: avoid circular references
static void value_to_json_internal(std::ostringstream & oss, const value & val, int curr_lvl, int indent, const std::string_view item_sep, const std::string_view key_sep) {
auto indent_str = [indent, curr_lvl]() -> std::string {
return (indent > 0) ? std::string(curr_lvl * indent, ' ') : "";
@ -1196,7 +1243,8 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
size_t i = 0;
for (const auto & pair : obj) {
oss << indent_str() << (indent > 0 ? std::string(indent, ' ') : "");
oss << "\"" << pair.first << "\"" << key_sep;
value_to_json_internal(oss, mk_val<value_string>(pair.first->as_string().str()), curr_lvl + 1, indent, item_sep, key_sep);
oss << key_sep;
value_to_json_internal(oss, pair.second, curr_lvl + 1, indent, item_sep, key_sep);
if (i < obj.size() - 1) {
oss << item_sep;
@ -1219,4 +1267,19 @@ std::string value_to_json(const value & val, int indent, const std::string_view
return oss.str();
}
// TODO: avoid circular references
std::string value_to_string_repr(const value & val) {
if (is_val<value_string>(val)) {
const std::string val_str = val->as_string().str();
if (val_str.find('\'') != std::string::npos) {
return value_to_json(val);
} else {
return "'" + val_str + "'";
}
} else {
return val->as_repr();
}
}
} // namespace jinja

View File

@ -1,8 +1,10 @@
#pragma once
#include "string.h"
#include "utils.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <map>
@ -93,7 +95,8 @@ void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input);
struct func_args; // function argument values
using func_handler = std::function<value(const func_args &)>;
using func_hptr = value(const func_args &);
using func_handler = std::function<func_hptr>;
using func_builtins = std::map<std::string, func_handler>;
enum value_compare_op { eq, ge, gt, lt, ne };
@ -103,28 +106,9 @@ struct value_t {
int64_t val_int;
double val_flt;
string val_str;
bool val_bool;
std::vector<value> val_arr;
struct map {
// once set to true, all keys must be numeric
// caveat: we only allow either all numeric keys or all non-numeric keys
// for now, this only applied to for_statement in case of iterating over object keys/items
bool is_key_numeric = false;
std::map<std::string, value> unordered;
std::vector<std::pair<std::string, value>> ordered;
void insert(const std::string & key, const value & val) {
if (unordered.find(key) != unordered.end()) {
// if key exists, remove from ordered list
ordered.erase(std::remove_if(ordered.begin(), ordered.end(),
[&](const std::pair<std::string, value> & p) { return p.first == key; }),
ordered.end());
}
unordered[key] = val;
ordered.push_back({key, val});
}
} val_obj;
std::vector<std::pair<value, value>> val_obj;
func_handler val_func;
@ -139,6 +123,7 @@ struct value_t {
value_t(const value_t &) = default;
virtual ~value_t() = default;
// Note: only for debugging and error reporting purposes
virtual std::string type() const { return ""; }
virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
@ -146,7 +131,7 @@ struct value_t {
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
virtual bool is_none() const { return false; }
virtual bool is_undefined() const { return false; }
@ -154,43 +139,66 @@ struct value_t {
throw std::runtime_error("No builtins available for type " + type());
}
virtual bool has_key(const std::string & key) {
return val_obj.unordered.find(key) != val_obj.unordered.end();
}
virtual value & at(const std::string & key, value & default_val) {
auto it = val_obj.unordered.find(key);
if (it == val_obj.unordered.end()) {
return default_val;
}
return val_obj.unordered.at(key);
}
virtual value & at(const std::string & key) {
auto it = val_obj.unordered.find(key);
if (it == val_obj.unordered.end()) {
throw std::runtime_error("Key '" + key + "' not found in value of type " + type());
}
return val_obj.unordered.at(key);
}
virtual value & at(int64_t index, value & default_val) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
return default_val;
}
return val_arr[index];
}
virtual value & at(int64_t index) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
}
return val_arr[index];
}
virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
virtual bool is_numeric() const { return false; }
virtual bool is_hashable() const { return false; }
virtual bool is_immutable() const { return true; }
virtual hasher unique_hash() const noexcept = 0;
// TODO: C++20 <=> operator
// NOTE: We are treating == as equivalent (for normal comparisons) and != as strict nonequal (for strict (is) comparisons)
virtual bool operator==(const value_t & other) const { return equivalent(other); }
virtual bool operator!=(const value_t & other) const { return nonequal(other); }
// Note: only for debugging purposes
virtual std::string as_repr() const { return as_string().str(); }
protected:
virtual bool equivalent(const value_t &) const = 0;
virtual bool nonequal(const value_t & other) const { return !equivalent(other); }
};
//
// utils
//
const func_builtins & global_builtins();
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
// Note: only used for debugging purposes
std::string value_to_string_repr(const value & val);
struct not_implemented_exception : public std::runtime_error {
not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
};
struct value_hasher {
size_t operator()(const value & val) const noexcept {
return val->unique_hash().digest();
}
};
struct value_equivalence {
bool operator()(const value & lhs, const value & rhs) const {
return *lhs == *rhs;
}
bool operator()(const std::pair<value, value> & lhs, const std::pair<value, value> & rhs) const {
return *(lhs.first) == *(rhs.first) && *(lhs.second) == *(rhs.second);
}
};
struct value_equality {
bool operator()(const value & lhs, const value & rhs) const {
return !(*lhs != *rhs);
}
};
//
@ -198,24 +206,49 @@ struct value_t {
//
struct value_int_t : public value_t {
value_int_t(int64_t v) { val_int = v; }
value_int_t(int64_t v) {
val_int = v;
val_flt = static_cast<double>(v);
if (static_cast<int64_t>(val_flt) != v) {
val_flt = v < 0 ? -INFINITY : INFINITY;
}
}
virtual std::string type() const override { return "Integer"; }
virtual int64_t as_int() const override { return val_int; }
virtual double as_float() const override { return static_cast<double>(val_int); }
virtual double as_float() const override { return val_flt; }
virtual string as_string() const override { return std::to_string(val_int); }
virtual bool as_bool() const override {
return val_int != 0;
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_numeric() const override { return true; }
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
return hasher(typeid(*this))
.update(&val_int, sizeof(val_int))
.update(&val_flt, sizeof(val_flt));
}
protected:
virtual bool equivalent(const value_t & other) const override {
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
}
virtual bool nonequal(const value_t & other) const override {
return !(typeid(*this) == typeid(other) && val_int == other.val_int);
}
};
using value_int = std::shared_ptr<value_int_t>;
struct value_float_t : public value_t {
value_float_t(double v) { val_flt = v; }
value val;
value_float_t(double v) {
val_flt = v;
val_int = std::isfinite(v) ? static_cast<int64_t>(v) : 0;
val = mk_val<value_int>(val_int);
}
virtual std::string type() const override { return "Float"; }
virtual double as_float() const override { return val_flt; }
virtual int64_t as_int() const override { return static_cast<int64_t>(val_flt); }
virtual int64_t as_int() const override { return val_int; }
virtual string as_string() const override {
std::string out = std::to_string(val_flt);
out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros
@ -226,6 +259,24 @@ struct value_float_t : public value_t {
return val_flt != 0.0;
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_numeric() const override { return true; }
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
if (static_cast<double>(val_int) == val_flt) {
return val->unique_hash();
} else {
return hasher(typeid(*this))
.update(&val_int, sizeof(val_int))
.update(&val_flt, sizeof(val_flt));
}
}
protected:
virtual bool equivalent(const value_t & other) const override {
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
}
virtual bool nonequal(const value_t & other) const override {
return !(typeid(*this) == typeid(other) && val_flt == other.val_flt);
}
};
using value_float = std::shared_ptr<value_float_t>;
@ -247,19 +298,49 @@ struct value_string_t : public value_t {
return val_str.length() > 0;
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
const auto type_hash = typeid(*this).hash_code();
auto hash = hasher();
hash.update(&type_hash, sizeof(type_hash));
val_str.hash_update(hash);
return hash;
}
void mark_input() {
val_str.mark_input();
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && val_str.str() == other.val_str.str();
}
};
using value_string = std::shared_ptr<value_string_t>;
struct value_bool_t : public value_t {
value_bool_t(bool v) { val_bool = v; }
value val;
value_bool_t(bool v) {
val_int = static_cast<int64_t>(v);
val_flt = static_cast<double>(v);
val = mk_val<value_int>(val_int);
}
virtual std::string type() const override { return "Boolean"; }
virtual bool as_bool() const override { return val_bool; }
virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); }
virtual int64_t as_int() const override { return val_int; }
virtual bool as_bool() const override { return val_int; }
virtual string as_string() const override { return std::string(val_int ? "True" : "False"); }
virtual const func_builtins & get_builtins() const override;
virtual bool is_numeric() const override { return true; }
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
return val->unique_hash();
}
protected:
virtual bool equivalent(const value_t & other) const override {
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
}
virtual bool nonequal(const value_t & other) const override {
return !(typeid(*this) == typeid(other) && val_int == other.val_int);
}
};
using value_bool = std::shared_ptr<value_bool_t>;
@ -269,13 +350,34 @@ struct value_array_t : public value_t {
value_array_t(value & v) {
val_arr = v->val_arr;
}
value_array_t(std::vector<value> && arr) {
val_arr = arr;
}
value_array_t(const std::vector<value> & arr) {
val_arr = arr;
}
void reverse() { std::reverse(val_arr.begin(), val_arr.end()); }
void push_back(const value & val) { val_arr.push_back(val); }
void push_back(value && val) { val_arr.push_back(std::move(val)); }
void reverse() {
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
std::reverse(val_arr.begin(), val_arr.end());
}
void push_back(const value & val) {
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
val_arr.push_back(val);
}
void push_back(value && val) {
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
val_arr.push_back(std::move(val));
}
value pop_at(int64_t index) {
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
if (index < 0) {
index = static_cast<int64_t>(val_arr.size()) + index;
}
@ -287,64 +389,225 @@ struct value_array_t : public value_t {
return val;
}
virtual std::string type() const override { return "Array"; }
virtual bool is_immutable() const override { return false; }
virtual const std::vector<value> & as_array() const override { return val_arr; }
virtual string as_string() const override {
const bool immutable = is_immutable();
std::ostringstream ss;
ss << "[";
ss << (immutable ? "(" : "[");
for (size_t i = 0; i < val_arr.size(); i++) {
if (i > 0) ss << ", ";
ss << val_arr.at(i)->as_repr();
value val = val_arr.at(i);
ss << value_to_string_repr(val);
}
ss << "]";
if (immutable && val_arr.size() == 1) {
ss << ",";
}
ss << (immutable ? ")" : "]");
return ss.str();
}
virtual bool as_bool() const override {
return !val_arr.empty();
}
virtual value & at(int64_t index, value & default_val) override {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
return default_val;
}
return val_arr[index];
}
virtual value & at(int64_t index) override {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
}
return val_arr[index];
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_hashable() const override {
if (std::all_of(val_arr.begin(), val_arr.end(), [&](auto & val) -> bool {
return val->is_immutable() && val->is_hashable();
})) {
return true;
}
return false;
}
virtual hasher unique_hash() const noexcept override {
auto hash = hasher(typeid(*this));
for (const auto & val : val_arr) {
// must use digest to prevent problems from "concatenation" property of hasher
// for ex. hash of [ "ab", "c" ] should be different from [ "a", "bc" ]
const size_t val_hash = val->unique_hash().digest();
hash.update(&val_hash, sizeof(size_t));
}
return hash;
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
}
};
using value_array = std::shared_ptr<value_array_t>;
struct value_tuple_t : public value_array_t {
value_tuple_t(value & v) {
val_arr = v->val_arr;
}
value_tuple_t(std::vector<value> && arr) {
val_arr = arr;
}
value_tuple_t(const std::vector<value> & arr) {
val_arr = arr;
}
value_tuple_t(const std::pair<value, value> & pair) {
val_arr.push_back(pair.first);
val_arr.push_back(pair.second);
}
virtual std::string type() const override { return "Tuple"; }
virtual bool is_immutable() const override { return true; }
};
using value_tuple = std::shared_ptr<value_tuple_t>;
struct value_object_t : public value_t {
std::unordered_map<value, value, value_hasher, value_equivalence> unordered;
bool has_builtins = true; // context and loop objects do not have builtins
value_object_t() = default;
value_object_t(value & v) {
val_obj = v->val_obj;
}
value_object_t(const std::map<std::string, value> & obj) {
for (const auto & pair : obj) {
val_obj.insert(pair.first, pair.second);
for (const auto & pair : val_obj) {
unordered[pair.first] = pair.second;
}
}
value_object_t(const std::vector<std::pair<std::string, value>> & obj) {
value_object_t(const std::map<value, value> & obj) {
for (const auto & pair : obj) {
val_obj.insert(pair.first, pair.second);
insert(pair.first, pair.second);
}
}
value_object_t(const std::vector<std::pair<value, value>> & obj) {
for (const auto & pair : obj) {
insert(pair.first, pair.second);
}
}
void insert(const std::string & key, const value & val) {
val_obj.insert(key, val);
insert(mk_val<value_string>(key), val);
}
virtual std::string type() const override { return "Object"; }
virtual const std::vector<std::pair<std::string, value>> & as_ordered_object() const override { return val_obj.ordered; }
virtual bool is_immutable() const override { return false; }
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const override { return val_obj; }
virtual string as_string() const override {
std::ostringstream ss;
ss << "{";
for (size_t i = 0; i < val_obj.size(); i++) {
if (i > 0) ss << ", ";
auto & [key, val] = val_obj.at(i);
ss << value_to_string_repr(key) << ": " << value_to_string_repr(val);
}
ss << "}";
return ss.str();
}
virtual bool as_bool() const override {
return !val_obj.unordered.empty();
return !unordered.empty();
}
virtual bool has_key(const value & key) override {
if (!key->is_immutable() || !key->is_hashable()) {
throw std::runtime_error("Object key of unhashable type: " + key->type());
}
return unordered.find(key) != unordered.end();
}
virtual void insert(const value & key, const value & val) override {
bool replaced = false;
if (is_immutable()) {
throw std::runtime_error("Attempting to modify immutable type");
}
if (has_key(key)) {
// if key exists, replace value in ordered list instead of appending
for (auto & pair : val_obj) {
if (*(pair.first) == *key) {
pair.second = val;
replaced = true;
break;
}
}
}
unordered[key] = val;
if (!replaced) {
val_obj.push_back({key, val});
}
}
virtual value & at(const value & key, value & default_val) override {
if (!has_key(key)) {
return default_val;
}
return unordered.at(key);
}
virtual value & at(const value & key) override {
if (!has_key(key)) {
throw std::runtime_error("Key '" + key->as_string().str() + "' not found in value of type " + type());
}
return unordered.at(key);
}
virtual value & at(const std::string & key, value & default_val) override {
value key_val = mk_val<value_string>(key);
return at(key_val, default_val);
}
virtual value & at(const std::string & key) override {
value key_val = mk_val<value_string>(key);
return at(key_val);
}
virtual const func_builtins & get_builtins() const override;
virtual bool is_hashable() const override {
if (std::all_of(val_obj.begin(), val_obj.end(), [&](auto & pair) -> bool {
const auto & val = pair.second;
return val->is_immutable() && val->is_hashable();
})) {
return true;
}
return false;
}
virtual hasher unique_hash() const noexcept override {
auto hash = hasher(typeid(*this));
for (const auto & [key, val] : val_obj) {
// must use digest to prevent problems from "concatenation" property of hasher
// for ex. hash of key="ab", value="c" should be different from key="a", value="bc"
const size_t key_hash = key->unique_hash().digest();
const size_t val_hash = val->unique_hash().digest();
hash.update(&key_hash, sizeof(key_hash));
hash.update(&val_hash, sizeof(val_hash));
}
return hash;
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
}
};
using value_object = std::shared_ptr<value_object_t>;
//
// null and undefined types
// none and undefined types
//
struct value_none_t : public value_t {
virtual std::string type() const override { return "None"; }
virtual bool is_none() const override { return true; }
virtual bool as_bool() const override { return false; }
virtual string as_string() const override { return string("None"); }
virtual string as_string() const override { return string(type()); }
virtual std::string as_repr() const override { return type(); }
virtual const func_builtins & get_builtins() const override;
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
return hasher(typeid(*this));
}
protected:
virtual bool equivalent(const value_t & other) const override {
return typeid(*this) == typeid(other);
}
};
using value_none = std::shared_ptr<value_none_t>;
@ -356,6 +619,13 @@ struct value_undefined_t : public value_t {
virtual bool as_bool() const override { return false; }
virtual std::string as_repr() const override { return type(); }
virtual const func_builtins & get_builtins() const override;
virtual hasher unique_hash() const noexcept override {
return hasher(typeid(*this));
}
protected:
virtual bool equivalent(const value_t & other) const override {
return is_undefined() == other.is_undefined();
}
};
using value_undefined = std::shared_ptr<value_undefined_t>;
@ -436,7 +706,23 @@ struct value_func_t : public value_t {
return val_func(new_args);
}
virtual std::string type() const override { return "Function"; }
virtual std::string as_repr() const override { return type(); }
virtual std::string as_repr() const override { return type() + "<" + name + ">(" + (arg0 ? arg0->as_repr() : "") + ")"; }
virtual bool is_hashable() const override { return false; }
virtual hasher unique_hash() const noexcept override {
// Note: this is unused for now, we don't support function as object keys
// use function pointer as unique identifier
const auto target = val_func.target<func_hptr>();
return hasher(typeid(*this)).update(&target, sizeof(target));
}
protected:
virtual bool equivalent(const value_t & other) const override {
// Note: this is unused for now, we don't support function as object keys
// compare function pointers
// (val_func == other.val_func does not work as std::function::operator== is only used for nullptr check)
const auto target_this = this->val_func.target<func_hptr>();
const auto target_other = other.val_func.target<func_hptr>();
return typeid(*this) == typeid(other) && target_this == target_other;
}
};
using value_func = std::shared_ptr<value_func_t>;
@ -447,18 +733,21 @@ struct value_kwarg_t : public value_t {
value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {}
virtual std::string type() const override { return "KwArg"; }
virtual std::string as_repr() const override { return type(); }
virtual bool is_hashable() const override { return true; }
virtual hasher unique_hash() const noexcept override {
const auto type_hash = typeid(*this).hash_code();
auto hash = val->unique_hash();
hash.update(&type_hash, sizeof(type_hash))
.update(key.data(), key.size());
return hash;
}
protected:
virtual bool equivalent(const value_t & other) const override {
const value_kwarg_t & other_val = static_cast<const value_kwarg_t &>(other);
return typeid(*this) == typeid(other) && key == other_val.key && val == other_val.val;
}
};
using value_kwarg = std::shared_ptr<value_kwarg_t>;
// utils
const func_builtins & global_builtins();
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
struct not_implemented_exception : public std::runtime_error {
not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
};
} // namespace jinja

View File

@ -192,12 +192,12 @@ void common_ngram_cache_draft(
break;
}
LOG(" - draft candidate: token=%d\n", drafted_token);
LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
draft.push_back(drafted_token);
}
}
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
std::ofstream file_out(filename, std::ios::binary);
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
const common_ngram ngram = item.first;
@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
}
}
}
common_ngram_cache common_ngram_cache_load(std::string & filename) {
common_ngram_cache common_ngram_cache_load(const std::string & filename) {
std::ifstream hashmap_file(filename, std::ios::binary);
if (!hashmap_file) {
throw std::ifstream::failure("Unable to open file " + filename);

View File

@ -88,12 +88,12 @@ void common_ngram_cache_draft(
// Save an ngram cache to a file.
// ngram_cache: the ngram cache to save.
// filename: the path under which to save the ngram cache.
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
// Load an ngram cache saved with common_ngram_cache_save.
// filename: the path from which to load the ngram cache.
// returns: an ngram cache containing the information saved to filename.
common_ngram_cache common_ngram_cache_load(std::string & filename);
common_ngram_cache common_ngram_cache_load(const std::string & filename);
// Merge two ngram caches.
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.

367
common/ngram-map.cpp Normal file
View File

@ -0,0 +1,367 @@
#include "common.h"
#include "log.h"
#include "ngram-map.h"
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <sstream>
// n-gram simple
//
/**
* Perform speculative generation using the model's own token history.
* Searches for a matching pattern in the token history and returns draft tokens.
*
* @param state Current state of this implementation
* @param tokens Token history to search in
* @param sampled Last sampled token
* @return Vector of draft tokens, empty if no matching pattern is found
*/
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const llama_tokens & tokens, llama_token sampled) {
// Simple implementation of self-speculative decoding without a draft model.
//
const size_t cur_len = tokens.size();
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (state.idx_last_check + state.config.check_rate > cur_len) {
llama_tokens draft_tokens;
return draft_tokens;
}
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
// vector for tokens we want to verify.
// return empty vector if there is no match.
llama_tokens draft_tokens;
// We need at least n_draft_min + n_draft_max + 1 tokens.
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
return draft_tokens;
}
// pattern search
llama_tokens pattern;
pattern.reserve(n_draft_min);
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
pattern.push_back(tokens[j]);
}
pattern.push_back(sampled); // add the last token to the pattern
// We do a search in the token history.
state.idx_last_check = cur_len;
size_t match_pos = 0; // we ignore position 0, position 0 == no match
// search backwards, but skip the current match (we are currently there)
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < pattern.size(); ++k) {
if (tokens[j + k] != pattern[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos == 0) {
return draft_tokens;
}
const size_t copy_max = std::min(
n_draft_max,
cur_len - (match_pos + n_draft_min)
);
if (copy_max < n_draft_min) {
return draft_tokens;
}
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
__func__, cur_len,
match_pos, pattern.size(), copy_max);
draft_tokens.reserve(copy_max);
for (size_t j = 0; j < copy_max; ++j) {
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
}
return draft_tokens;
}
// n-gram map
//
// maximum number of counted values of a ngram map value.
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
void common_ngram_map_draft(common_ngram_map & map,
const llama_tokens & inp, llama_token sampled,
llama_tokens & draft) {
// reset last key and value.
map.last_draft_created = false;
map.last_draft_key_idx = 0;
map.last_draft_value_idx = 0;
const size_t cur_len = inp.size();
const uint16_t n = map.size_key;
const uint16_t m = map.size_value;
if (cur_len < static_cast<size_t>(2 * n + m)) {
return;
}
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (map.idx_last_check + map.check_rate > cur_len) {
return;
}
map.idx_last_check = cur_len;
// search pattern, the key n-gram
std::vector<llama_token> key_tokens;
key_tokens.reserve(n);
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
key_tokens.push_back(inp[j]);
}
key_tokens.push_back(sampled);
// search for the key in the map
size_t match_pos = 0;
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < n; ++k) {
if (inp[j + k] != key_tokens[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos > 0) {
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
cur_len, n, m, key_tokens.size(), sampled, match_pos);
}
if (match_pos == 0) {
return;
}
// We have a match, now we look for the statistics of the key.
size_t key_offset = map.keys.size(); // offset in the map
// We iterate through the std::vector<common_ngram_map_key> map->keys.
for (size_t i = 0; i < map.keys.size(); ++i) {
bool match = true;
for (size_t j = 0; j < n; ++j) {
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
match = false;
break;
}
}
if (match) {
key_offset = i;
break;
}
}
if (key_offset == map.keys.size()) {
// We create a new key-entry, it will get offset key_offset.
common_ngram_map_key new_key;
new_key.key_idx = match_pos;
new_key.stat_idx = 0;
new_key.key_num = 0;
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
new_key.values[i].value_num = 0;
new_key.values[i].n_accepted = m;
}
map.keys.push_back(new_key);
}
// our key n-gram:
common_ngram_map_key & curr_key = map.keys[key_offset];
// update number of key hits
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
if (map.key_only) {
// simple mode:
// Fill in the draft with the m tokens following the key.
// We work with value values[0] only.
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
for (int i = 0; i < n_draft_tokens; ++i) {
draft.push_back(inp[match_pos + n + i]);
}
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
key_offset, curr_key.key_num, draft.size());
map.last_draft_created = false;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;
}
if (curr_key.key_num < map.min_hits) {
// not enough hits to consider this a good draft
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
key_offset, curr_key.key_num, map.min_hits);
return;
}
// complex mode: examine the different m-grams after this key n-gram.
//
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
// begins the key n-gram at index i?
bool match_key = true;
for (size_t k = 0; k < n; ++k) {
if (inp[i + k] != key_tokens[k]) {
match_key = false;
break;
}
}
if (!match_key) {
continue;
}
// Do we haven a existing value m-gram or a new one after the key at index i?
size_t idx_begin_value_key = i + n;
int idx_value = -1;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
size_t idx_begin_value_v = curr_key.values[v].value_idx;
if (idx_begin_value_v == 0) {
// We found an empty value slot => we found a new value m-gram after the key n-gram.
curr_key.values[v].value_idx = idx_begin_value_key;
curr_key.values[v].value_num = 0;
curr_key.values[v].n_accepted = m;
idx_value = v;
break;
}
bool match = true;
for (size_t j = 0; j < m; ++j) {
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
match = false;
break;
}
}
if (match) {
// We found an existing value m-gram after the key n-gram.
idx_value = v;
break;
}
}
if (idx_value >= 0) {
// We found a value m-gram of the key n-gram.
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
}
}
// the statistics are updated up to match_pos.
curr_key.stat_idx = match_pos;
// Do we have a value we could use for the draft?
uint16_t max_occur = 0;
int slot_max = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
uint16_t curr_occur = curr_key.values[v].value_num;
if (curr_occur > max_occur) {
max_occur = curr_occur;
slot_max = v;
}
}
// What is sum of the other occurences?
uint32_t sum_occur = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (v == slot_max) {
continue;
}
uint16_t curr_occur = curr_key.values[v].value_num;
sum_occur += curr_occur;
}
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
key_offset,
max_occur, sum_occur, slot_max,
curr_key.values[0].value_idx, curr_key.values[0].value_num,
curr_key.values[1].value_idx, curr_key.values[1].value_num,
curr_key.values[2].value_idx, curr_key.values[2].value_num,
curr_key.values[3].value_idx, curr_key.values[3].value_num
);
// Print the tokens of the four values (if idx != 0), use LOG_INF
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (curr_key.values[v].value_idx != 0) {
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
}
}
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
// The most frequent value is not much more frequent than the other values.
// We do not use the draft.
return;
}
// We use the most frequent value values[slot_max] for the draft.
// Fill in the draft with the m tokens following the key.
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
for (int i = 0; i < n_draft_tokens; ++i) {
draft.push_back(inp[match_pos + n + i]);
}
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
key_offset, slot_max,
curr_key.key_num, draft.size());
map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = slot_max; // value used for draft generation.
}
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
if (!map.last_draft_created) {
return;
}
// find the key and its chosen value.
const size_t key_idx = map.last_draft_key_idx;
const size_t val_idx = map.last_draft_value_idx;
// find key corresponding to key_idx.
common_ngram_map_key & curr_key = map.keys[key_idx];
// find value corresponding to val_idx.
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
// update the value statistics
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
n_accepted, curr_value.n_accepted);
curr_value.n_accepted = n_accepted;
}
// Helper functions.
//
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
std::ostringstream oss;
oss << '[';
for (size_t i = 0; i < length; ++i) {
if (i > 0) {
oss << ", ";
}
oss << inp[start + i];
}
oss << ']';
return oss.str();
}

105
common/ngram-map.h Normal file
View File

@ -0,0 +1,105 @@
#pragma once
//
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
//
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
//
// There are two algorithms implemented:
// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
//
#include "llama.h"
#include <vector>
// n-gram simple
//
// config of n-gram simple.
struct common_ngram_simple_config {
uint16_t size_ngram; // size of n-grams to lookup in self-mode
uint16_t size_mgram; // size of m-grams to draft in self-mode
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
};
// current state (and config) of n-gram simple.
struct common_ngram_simple_state {
common_ngram_simple_config config;
size_t idx_last_check = 0; // index of last check in context history (mutable)
common_ngram_simple_state(const common_ngram_simple_config & config)
: config(config) {}
};
// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
// state: the ngram simple state to search in.
// inp: the tokens generated so far.
// sampled: the token that was just sampled.
// draft: vector to store the draft tokens, initially empty.
llama_tokens common_ngram_simple_draft(
common_ngram_simple_state & state,
const llama_tokens & tokens, llama_token sampled);
// n-gram map
//
// maximum number of m-gram values stored for each key n-gram.
#define COMMON_NGRAM_MAX_VALUES 4
// statistics of a m-gram after a known n-gram
struct common_ngram_map_value {
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
};
// statistics of a n-gram
struct common_ngram_map_key {
size_t key_idx; // index of key n-gram in token-history
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
uint16_t key_num; // number of occurences of this key n-gram in token-history
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
};
// map from n-grams to following m-grams in token-history
struct common_ngram_map {
uint16_t size_key; // size of key n-grams
uint16_t size_value; // size of value m-grams
bool key_only; // true if only key n-grams are used, no values.
// first draft: vector only, no map.
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
uint16_t min_hits; // minimum number of key hits to consider a draft
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
uint16_t check_rate, uint16_t min_hits)
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
check_rate(check_rate), min_hits(min_hits) {}
bool last_draft_created = false; // true if a draft was created at last call.
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
size_t idx_last_check = 0; // index of last check in context history
};
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
// map: the ngram map to search in.
// inp: the tokens generated so far.
// sampled: the token that was just sampled.
// draft: vector to store the draft tokens, initially empty.
void common_ngram_map_draft(
common_ngram_map & map,
const llama_tokens & inp, llama_token sampled,
llama_tokens & draft);
// Update the statistics of a value after a draft was processed.
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);

File diff suppressed because it is too large Load Diff

View File

@ -5,31 +5,33 @@
struct common_speculative;
struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
// comma separated list of all types
std::string common_speculative_type_name_str();
float p_min = 0.75f; // min probability required to accept a token in the draft
};
// convert string to type
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft
);
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
void common_speculative_free(struct common_speculative * spec);
common_speculative * common_speculative_init(
const common_params_speculative & params,
llama_context * ctx_tgt);
bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft);
void common_speculative_free(common_speculative * spec);
void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest);
// optionally call once at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
llama_tokens common_speculative_draft(
common_speculative * spec,
const common_params_speculative & params,
const llama_tokens & prompt,
llama_token id_last);
// informs the speculative decoder that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);

View File

@ -2736,7 +2736,7 @@ class AfmoeModel(LlamaModel):
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
yield from super().modify_tensors(data_torch, merged_name, bid)
yield from ModelBase.modify_tensors(self, data_torch, merged_name, bid)
return
else:
@ -2745,7 +2745,7 @@ class AfmoeModel(LlamaModel):
if name.endswith(".expert_bias"):
name = name.replace(".expert_bias", ".expert_bias.bias")
yield from super().modify_tensors(data_torch, name, bid)
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
@ModelBase.register(
@ -3799,7 +3799,7 @@ class Ernie4_5MoeModel(Ernie4_5Model):
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
yield from super().modify_tensors(data_torch, merged_name, bid)
else:
yield from super().modify_tensors(data_torch, name, bid)
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
def prepare_tensors(self):
super().prepare_tensors()
@ -6145,7 +6145,8 @@ class Gemma3nVisionAudioModel(ConformerAudioModel):
if name.startswith("model.vision_tower.timm_model.blocks."):
# Double-indexed block tensors through custom logic
new_name = self.custom_map(name)
yield (self.custom_map(name), data_torch)
return
else:
# Route non-repeating (conv_stem, msfa, embedding, etc.) and un-catched through tensor_mapping.py
new_name = self.map_tensor_name(name)
@ -6153,7 +6154,7 @@ class Gemma3nVisionAudioModel(ConformerAudioModel):
if new_name.endswith("conv_stem.conv.bias") or new_name.endswith("layer_scale.gamma"):
data_torch = data_torch.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # [1, C, 1, 1]
yield from super().modify_tensors(data_torch, new_name, bid)
yield from ModelBase.modify_tensors(self, data_torch, new_name, bid)
@ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration")
@ -6253,7 +6254,7 @@ class Gemma3NModel(Gemma3Model):
# Continue with normal processing
name = name.replace("language_model.", "")
yield from super().modify_tensors(data_torch, name, bid)
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
return
if "altup_unembed_projections" in name:
@ -6270,7 +6271,7 @@ class Gemma3NModel(Gemma3Model):
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_unembd)
if out is not None:
yield from super().modify_tensors(out, "model.altup_unembed_projections.weight", bid)
yield from ModelBase.modify_tensors(self, out, "model.altup_unembed_projections.weight", bid)
return
else:
return
@ -6287,7 +6288,7 @@ class Gemma3NModel(Gemma3Model):
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_proj)
if out is not None:
yield from super().modify_tensors(out, "model.altup_projections.weight", bid)
yield from ModelBase.modify_tensors(self, out, "model.altup_projections.weight", bid)
return
else:
return
@ -8803,8 +8804,8 @@ class GraniteMoeModel(GraniteModel):
ffn_dim = self.hparams["intermediate_size"]
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
gate, up = data_torch.split(ffn_dim, dim=-2)
yield from super().modify_tensors(gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), bid)
yield from super().modify_tensors(up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), bid)
yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), bid)
yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), bid)
has_experts = bool(self.hparams.get('num_local_experts'))
@ -8813,15 +8814,15 @@ class GraniteMoeModel(GraniteModel):
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
gate, up = data_torch.split(ffn_dim, dim=-2)
if has_experts:
yield from super().modify_tensors(gate,self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), bid)
yield from super().modify_tensors(up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), bid)
yield from ModelBase.modify_tensors(self, gate,self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), bid)
yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), bid)
return
yield from super().modify_tensors(gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), bid)
yield from super().modify_tensors(up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), bid)
yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), bid)
yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), bid)
return
if not has_experts and name.endswith("shared_mlp.output_linear.weight"):
yield from super().modify_tensors(data_torch, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), bid)
yield from ModelBase.modify_tensors(self, data_torch, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), bid)
return
yield from super().modify_tensors(data_torch, name, bid)
@ -8911,14 +8912,17 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
name.endswith("block_sparse_moe.input_linear.weight")
or "shared_mlp" in name
):
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
return
# Determine whether this is a mamba layer or an attention layer
if bid in self._ssm_layers:
return Mamba2Model.modify_tensors(self, data_torch, name, bid)
yield from Mamba2Model.modify_tensors(self, data_torch, name, bid)
return
elif bid in self._attn_layers:
return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
yield from super().modify_tensors(data_torch, name, bid)
yield from GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
return
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
def set_gguf_parameters(self):
"""This method merges params from both parents and some that are
@ -9050,33 +9054,33 @@ class NemotronHModel(GraniteHybridModel):
if self.is_moe and bid is not None:
if name.endswith("mixer.gate.e_score_correction_bias"):
new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
yield from super().modify_tensors(data_torch, new_name, bid)
yield from ModelBase.modify_tensors(self, data_torch, new_name, bid)
return
if name.endswith("mixer.dt_bias"):
new_name = name.replace("dt_bias", "dt.bias")
yield from super().modify_tensors(data_torch, new_name, bid)
yield from ModelBase.modify_tensors(self, data_torch, new_name, bid)
return
if name.endswith("mixer.conv1d.weight"):
squeezed_data = data_torch.squeeze()
yield from super().modify_tensors(squeezed_data, name, bid)
yield from ModelBase.modify_tensors(self, squeezed_data, name, bid)
return
if name.endswith("mixer.A_log"):
transformed_data = -torch.exp(data_torch)
reshaped_data = transformed_data.squeeze().reshape(-1, 1)
yield from super().modify_tensors(reshaped_data, name, bid)
yield from ModelBase.modify_tensors(self, reshaped_data, name, bid)
return
if name.endswith("mixer.D"):
reshaped_data = data_torch.squeeze().reshape(-1, 1)
yield from super().modify_tensors(reshaped_data, name, bid)
yield from ModelBase.modify_tensors(self, reshaped_data, name, bid)
return
if name.endswith("mixer.norm.weight"):
reshaped_data = data_torch.reshape(self.n_group, -1)
yield from super().modify_tensors(reshaped_data, name, bid)
yield from ModelBase.modify_tensors(self, reshaped_data, name, bid)
return
if name.find("mixer.experts") != -1:
@ -9101,7 +9105,7 @@ class NemotronHModel(GraniteHybridModel):
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
yield from super().modify_tensors(data_torch, merged_name, bid)
yield from ModelBase.modify_tensors(self, data_torch, merged_name, bid)
return
else:
return
@ -10731,7 +10735,7 @@ class CogVLMModel(LlamaModel):
if name.startswith("model.vision."):
return
yield from super().modify_tensors(data_torch, name, bid)
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
@ModelBase.register("JanusForConditionalGeneration")

View File

@ -144,7 +144,7 @@ We also have a [guide](./backend/CUDA-FEDORA.md) for setting up CUDA toolkit in
- ***Necessary*** for users of [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/); such as: [Silverblue](https://fedoraproject.org/atomic-desktops/silverblue/) and [Kinoite](https://fedoraproject.org/atomic-desktops/kinoite/).
- (there are no supported CUDA packages for these systems)
- ***Necessary*** for users that have a host that is not a: [Supported Nvidia CUDA Release Platform](https://developer.nvidia.com/cuda-downloads).
- (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your your host operating system)
- (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your host operating system)
- ***Convenient*** For those running [Fedora Workstation](https://fedoraproject.org/workstation/) or [Fedora KDE Plasma Desktop](https://fedoraproject.org/spins/kde), and want to keep their host system clean.
- *Optionally* toolbox packages are available: [Arch Linux](https://archlinux.org/), [Red Hat Enterprise Linux >= 8.5](https://www.redhat.com/en/technologies/linux-platforms/enterprise-linux), or [Ubuntu](https://ubuntu.com/download)
@ -248,6 +248,14 @@ You may set the [cuda environmental variables](https://docs.nvidia.com/cuda/cuda
CUDA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.gguf
```
#### CUDA_SCALE_LAUNCH_QUEUES
The environment variable [`CUDA_SCALE_LAUNCH_QUEUES`](https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/environment-variables.html#cuda-scale-launch-queues) controls the size of CUDA's command buffer, which determines how many GPU operations can be queued before the CPU must wait for the GPU to catch up. A larger buffer reduces CPU-side stalls and allows more work to be queued on a GPU.
**Default behavior:** llama.cpp automatically sets `CUDA_SCALE_LAUNCH_QUEUES=4x`, which increases the CUDA command buffer to 4 times its default size. This optimization is particularly beneficial for **Multi-GPU setups with pipeline parallelism**, where it significantly improves prompt processing throughput by allowing more operations to be enqueued across GPUs.
See PR [#19042](https://github.com/ggml-org/llama.cpp/pull/19042) for performance benchmarks and technical details.
### Unified Memory
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`.
@ -487,6 +495,37 @@ Finally, after finishing your build, you should be able to do something like thi
# ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32
```
### For Mac users:
Generally, follow LunarG's [Getting Started with the MacOS Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/mac/getting_started.html) guide for installation and setup of the Vulkan SDK. There are two options of Vulkan drivers on macOS, both of which implement translation layers to map Vulkan to Metal. They can be hot-swapped by setting the `VK_ICD_FILENAMES` environment variable to point to the respective ICD JSON file.
Check the box for "KosmicKrisp" during the LunarG Vulkan SDK installation.
Set environment variable for the LunarG Vulkan SDK after installation (and optionally add to your shell profile for persistence):
```bash
source /path/to/vulkan-sdk/setup-env.sh
```
#### Using MoltenVK
MoltenVK is the default Vulkan driver installed with the LunarG Vulkan SDK on macOS, so you can use the above environment variable settings as is.
#### Using KosmicKrisp
Override the environment variable for KosmicKrisp:
```bash
export VK_ICD_FILENAMES=$VULKAN_SDK/share/vulkan/icd.d/libkosmickrisp_icd.json
export VK_DRIVER_FILES=$VULKAN_SDK/share/vulkan/icd.d/libkosmickrisp_icd.json
```
#### Build
This is the only step different from [above](#common-steps) instructions.
```bash
cmake -B build -DGGML_VULKAN=1 -DGGML_METAL=OFF
cmake --build build --config Release
```
## CANN
This provides NPU acceleration using the AI cores of your Ascend NPU. And [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical APIs to help you to quickly build AI applications and service based on Ascend NPU.

File diff suppressed because it is too large Load Diff

120
docs/speculative.md Normal file
View File

@ -0,0 +1,120 @@
# Speculative Decoding
llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model.
[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
## Implementations
The `llama-server` application supports several implementations of speculative decoding:
### Draft Model (`draft`)
A much smaller model (called the _draft model_) generates drafts.
A draft model is the most used approach in speculative decoding.
### n-gram Cache (`ngram-cache`)
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
See:
- #5479, #6828, #6848
### n-gram Map (`ngram-simple`, `ngram-map-*`)
These implementations search the token history for patterns and use matching sequences as draft candidates.
They require no additional model but rely on patterns that have already appeared in the generated text.
An example to use this approach can be the rewriting of source code by a LLM.
#### n-gram Map (`ngram-simple`)
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
#### n-gram Map Key (`ngram-map-k`)
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts.
The number of accepted tokens is stored for each used n-gram.
#### n-gram Map Key-4-Values (`ngram-map-k4v`)
This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
The number of accepted tokens is stored for each used n-gram.
**Example:** Server options to be used if there are a lot of longer repetitions.
```bash
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2
```
## Command-Line Options
If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
```
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]
type of speculative decoding to use when no draft model is provided
(default: none)
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
of lookup n-gram (default: 12)
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
of draft m-gram (default: 48)
--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding
(default: 1)
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
```
### `--spec-type TYPE`
Specifies a type of speculative decoding without draft model.
| Type | Description |
|------|-------------|
| `none` | No speculative decoding (default) |
| `ngram-cache` | Use n-gram cache lookup |
| `ngram-simple` | Use simple n-gram pattern matching |
| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
**Example:** Server-instance used to refactor source code.
```bash
./llama-server [...] --spec-type ngram-simple
```
### `--spec-ngram-size-n N`
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
### `--spec-ngram-size-m M`
Sets the size M of the draft m-gram for n-gram map based speculative decoding.
The m-gram size determines how many tokens to draft when a match is found.
Larger values can provide more speedup but may reduce acceptance rate.
### `--spec-ngram-check-rate R`
This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
### `--spec-ngram-min-hits H`
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
## Statistics
Each speculative decoding implementation prints statistics.
```
draft acceptance rate = 0.57576 ( 171 accepted / 297 generated)
statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
```
- `#calls`: number of calls of this implementations
- `#gen drafts`: number of drafts generated by this implementation
- `#acc drafts`: number of drafts accepted (partially) by the main model
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
- `#acc tokens`: number of tokens accepted by the main model

View File

@ -32,9 +32,9 @@ int main(int argc, char ** argv){
common_ngram_cache ngram_cache;
common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
common_ngram_cache_save(ngram_cache, params.lookup_cache_static);
common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
return 0;
}

View File

@ -46,18 +46,18 @@ int main(int argc, char ** argv){
{
const int64_t t_start_draft_us = ggml_time_us();
if (!params.lookup_cache_static.empty()) {
if (!params.speculative.lookup_cache_static.empty()) {
try {
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
} catch (std::ifstream::failure const &) {
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
exit(1);
}
}
if (!params.lookup_cache_dynamic.empty()) {
if (!params.speculative.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
}

View File

@ -51,18 +51,18 @@ int main(int argc, char ** argv){
const int64_t t_start_draft_us = ggml_time_us();
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
if (!params.lookup_cache_static.empty()) {
if (!params.speculative.lookup_cache_static.empty()) {
try {
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
} catch (std::ifstream::failure const &) {
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
exit(1);
}
}
if (!params.lookup_cache_dynamic.empty()) {
if (!params.speculative.lookup_cache_dynamic.empty()) {
try {
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
}
@ -210,7 +210,7 @@ int main(int argc, char ** argv){
// Update dynamic ngram cache with context ngram cache and save it to disk:
common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
LOG("\n\n");

View File

@ -24,7 +24,7 @@ int main(int argc, char ** argv) {
common_init();
if (params.speculative.model.path.empty()) {
if (params.speculative.mparams_dft.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
@ -34,10 +34,8 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
llama_model * model_tgt = NULL;
//llama_model * model_dft = NULL;
llama_context * ctx_tgt = NULL;
llama_context * ctx_dft = NULL;
// load the target model
auto llama_init_tgt = common_init_from_params(params);
@ -48,26 +46,38 @@ int main(int argc, char ** argv) {
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.n_ctx = params.speculative.n_ctx;
params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
params.n_gpu_layers = params.speculative.n_gpu_layers;
llama_model_ptr model_dft;
if (params.speculative.cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
}
// TODO: simplify this logic
{
const auto & params_spec = params.speculative;
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
auto params_dft = params;
auto llama_init_dft = common_init_from_params(params);
params_dft.n_parallel = 1;
params_dft.n_ctx = params_spec.n_ctx;
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
params_dft.devices = params_spec.devices;
params_dft.model = params_spec.mparams_dft;
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
//model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
if (params_spec.cpuparams.n_threads > 0) {
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
}
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
auto mparams_dft = common_model_params_to_llama(params_dft);
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
if (model_dft == nullptr) {
LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
return 1;
}
params.speculative.model_dft = model_dft.get();
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
}
// Tokenize the prompt
@ -92,12 +102,6 @@ int main(int argc, char ** argv) {
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
}
// how many tokens to draft each time
int n_draft = params.speculative.n_max;
int n_draft_min = params.speculative.n_min;
float p_min = params.speculative.p_min;
int n_predict = 0;
int n_drafted = 0;
int n_accept = 0;
@ -127,15 +131,11 @@ int main(int argc, char ** argv) {
int n_past = inp.size() - 1;
// init the speculator
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft;
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
params_spec.p_min = p_min;
const auto & params_spec = params.speculative;
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
for (auto &pair : params.speculative.replacements) {
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
}
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
common_speculative_begin(spec, prompt_tgt);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
// from a cache or lookup tables.
//
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
// do not waste time on small drafts
if (draft.size() < (size_t) n_draft_min) {
if (draft.size() < (size_t) params_spec.n_min) {
draft.clear();
}
@ -240,7 +240,7 @@ int main(int argc, char ** argv) {
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_INF("\n");
LOG_INF("n_draft = %d\n", n_draft);
LOG_INF("n_draft = %d\n", params_spec.n_max);
LOG_INF("n_predict = %d\n", n_predict);
LOG_INF("n_drafted = %d\n", n_drafted);
LOG_INF("n_accept = %d\n", n_accept);
@ -249,8 +249,6 @@ int main(int argc, char ** argv) {
LOG_INF("\n");
LOG_INF("draft:\n\n");
llama_perf_context_print(ctx_dft);
LOG_INF("\n");
LOG_INF("target:\n\n");
common_perf_print(ctx_tgt, smpl);

View File

@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
common_init();
if (params.speculative.model.path.empty()) {
if (params.speculative.mparams_dft.path.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
}
@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.model = params.speculative.mparams_dft;
params.n_gpu_layers = params.speculative.n_gpu_layers;
if (params.speculative.cpuparams.n_threads > 0) {
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;

View File

@ -228,6 +228,8 @@ option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
option(GGML_ZDNN "ggml: use zDNN" OFF)
option(GGML_VIRTGPU "ggml: use the VirtGPU/Virglrenderer API Remoting frontend" OFF)
option(GGML_VIRTGPU_BACKEND "ggml: build the VirtGPU/Virglrenderer API Remoting backend" OFF)
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
@ -320,6 +322,7 @@ set(GGML_PUBLIC_HEADERS
include/ggml-opt.h
include/ggml-metal.h
include/ggml-rpc.h
include/ggml-virtgpu.h
include/ggml-sycl.h
include/ggml-vulkan.h
include/ggml-webgpu.h

View File

@ -0,0 +1,16 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_REMOTING_FRONTEND_NAME "RemotingFrontend"
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg();
#ifdef __cplusplus
}
#endif

View File

@ -451,6 +451,7 @@ ggml_add_backend(HIP)
ggml_add_backend(METAL)
ggml_add_backend(MUSA)
ggml_add_backend(RPC)
ggml_add_backend(VirtGPU)
ggml_add_backend(SYCL)
ggml_add_backend(Vulkan)
ggml_add_backend(WebGPU)

View File

@ -69,6 +69,10 @@
#include "ggml-rpc.h"
#endif
#ifdef GGML_USE_VIRTGPU_FRONTEND
#include "ggml-virtgpu.h"
#endif
#ifdef GGML_USE_CANN
#include "ggml-cann.h"
#endif
@ -180,7 +184,12 @@ struct ggml_backend_registry {
register_backend(ggml_backend_sycl_reg());
#endif
#ifdef GGML_USE_VULKAN
// Add runtime disable check
if (getenv("GGML_DISABLE_VULKAN") == nullptr) {
register_backend(ggml_backend_vk_reg());
} else {
GGML_LOG_DEBUG("Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\n");
}
#endif
#ifdef GGML_USE_WEBGPU
register_backend(ggml_backend_webgpu_reg());
@ -188,6 +197,10 @@ struct ggml_backend_registry {
#ifdef GGML_USE_ZDNN
register_backend(ggml_backend_zdnn_reg());
#endif
#ifdef GGML_USE_VIRTGPU_FRONTEND
register_backend(ggml_backend_virtgpu_reg());
#endif
#ifdef GGML_USE_OPENCL
register_backend(ggml_backend_opencl_reg());
#endif
@ -604,6 +617,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
ggml_backend_load_best("rpc", silent, dir_path);
ggml_backend_load_best("sycl", silent, dir_path);
ggml_backend_load_best("vulkan", silent, dir_path);
ggml_backend_load_best("virtgpu", silent, dir_path);
ggml_backend_load_best("opencl", silent, dir_path);
ggml_backend_load_best("hexagon", silent, dir_path);
ggml_backend_load_best("musa", silent, dir_path);

View File

@ -1,3 +1,4 @@
#pragma once
// Rename `_generic` functions if no native implementation is available.
@ -42,6 +43,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -53,6 +55,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -73,6 +76,7 @@
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
@ -80,6 +84,7 @@
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
@ -102,6 +107,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -113,6 +119,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -136,6 +143,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -147,6 +155,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -177,6 +186,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -187,6 +197,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -216,6 +227,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -227,6 +239,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -258,6 +271,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -269,6 +283,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0

View File

@ -1055,10 +1055,10 @@ void ggml_gemv_q5_K_8x8_q8_K(int n,
// FUSED BIAS: Compute and subtract bias immediately
// bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
float32x4_t bias_f32 = vcvtq_f32_s32(bias);
acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
}
} // for sb
} // for b
@ -1072,6 +1072,208 @@ void ggml_gemv_q5_K_8x8_q8_K(int n,
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q6_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_pairs = ncols_interleaved / 2;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[2];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
acc_f32[0] = vdupq_n_f32(0);
acc_f32[1] = vdupq_n_f32(0);
for (int b = 0; b < nb; b++) {
float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
int32x2_t acc[col_pairs];
for (int i = 0; i < col_pairs; i++) {
acc[i] = vdup_n_s32(0);
}
// Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
// Reused for bias and dequantization later
int16_t q6_scales[16 * 8];
for (int i = 0; i < 16; i++) {
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
vst1q_s16(q6_scales + i * 8, scales);
}
// Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
int32x4_t bias_lo = vdupq_n_s32(0);
int32x4_t bias_hi = vdupq_n_s32(0);
// Load bsums in chunks of 4 to process with vectorized operations
for (int i = 0; i < 16; i += 4) {
int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
}
bias_lo = vshlq_n_s32(bias_lo, 5);
bias_hi = vshlq_n_s32(bias_hi, 5);
// Process two 128-value halves per superblock
for (int half = 0; half < 2; half++) {
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
// A subblock (sb) is a set of weights that share the scale
// Since q6_K scales are per 16 elements
// num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
for (int sb = 0; sb < QK_K / 64; sb++) {
const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
const int8_t * q8_base_h = q8_base_l + 64;
// Load and duplicate q8 values (each register covers two interleaved columns of q6)
int8x16_t q8_l[2];
int8x16_t q8_h[2];
for (int i = 0; i < 2; i++) {
q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));
q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
}
// TODO: Test other qh repack patterns to reduce loads
const int ql_off_base = sb * QK_K / 2;
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base);
ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64);
ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base);
ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64);
// Adjust qh for subblocks 2 and 3 (shift right by 2)
if (sb > 1) {
q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
}
// Process column pairs (0-1, 2-3, 4-5, 6-7)
for (int cp = 0; cp < col_pairs; cp++) {
const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];
const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];
const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];
const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];
// Extract high 2 bits for upper nibble reconstruction
const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
// q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
const int8x16_t q6_l0 = vreinterpretq_s8_u8(
vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));
const int8x16_t q6_l1 = vreinterpretq_s8_u8(
vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));
const int8x16_t q6_h0 =
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));
const int8x16_t q6_h1 =
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));
int32x4_t sb_acc_l = vdupq_n_s32(0);
sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);
sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);
int32x4_t sb_acc_h = vdupq_n_s32(0);
sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);
sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);
// Pairwise add to get per-column sums: [col0, col1]
int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));
int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));
const int scale_idx_l = half * 8 + sb;
const int scale_idx_h = half * 8 + sb + 4;
// Access scales using array indexing (scales are interleaved by column)
const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],
(int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };
const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],
(int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };
// Accumulate scaled results
acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);
acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);
}
}
} // for half
// Bias correction
acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));
acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));
acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));
acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));
// Apply superblock scale (no mins for q6_K)
// acc[cp] has [c0, c1]
float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));
float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));
float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));
float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));
acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));
acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q8_0_4x4_q8_0(int n,
float * GGML_RESTRICT s,
size_t bs,
@ -2946,16 +3148,17 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
// Scales[i] corresponds to column i
const int scale_offset = cp * 2;
for (int blk = 0; blk < 2; blk++) {
const int32x4_t block_scale = {
(int32_t) q4sb_scales[blk][scale_offset],
(int32_t) q4sb_scales[blk][scale_offset],
(int32_t) q4sb_scales[blk][scale_offset + 1],
(int32_t) q4sb_scales[blk][scale_offset + 1],
};
acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
}
const int32_t scale_00 = q4sb_scales[0][scale_offset];
const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
const int32_t scale_10 = q4sb_scales[1][scale_offset];
const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
}
// Multiply Acc bsum + mins
@ -3146,8 +3349,8 @@ void ggml_gemm_q5_K_8x8_q8_K(int n,
const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
int32x4_t acc_0 = sb_acc[0];
acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
int32x4_t acc_2 = sb_acc[2];
acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
int32x4_t acc_2 = sb_acc[2];
acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
int32x4_t acc_1 = sb_acc[1];
acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
@ -3271,6 +3474,223 @@ void ggml_gemm_q5_K_8x8_q8_K(int n,
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q6_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
constexpr int q8_k_blocklen = 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
const int8x16_t m32s = vdupq_n_s8(32);
// 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
float32x4_t acc_f32[blocklen];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int i = 0; i < blocklen; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
for (int i = 0; i < 8; i++) {
acc[i] = vdupq_n_s32(0);
}
// Q6_K has simple 8-bit scales, 16 per block (one per 16 values)
// Reused for bias and dequantization later
int16_t q6_scales[16 * 8];
for (int i = 0; i < 16; ++i) {
int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
vst1q_s16(q6_scales + i * 8, s16);
}
// Process two 128-value halves per superblock
for (int half = 0; half < 2; half++) {
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
// A subblock (sb) is a set of weights that share the scale
// Since q6_K scales are per 16 elements
// num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
for (int sb = 0; sb < QK_K / 64; sb++) {
// Q6_K weight index increasing by 64 instead of 32 requires
// loading various q8 memory regions
const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
int8x16_t q8_l_01[2];
int8x16_t q8_l_23[2];
for (int i = 0; i < 2; i++) {
const int offset = i * 32;
q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01)
q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23)
}
int8x16_t q8_h_01[2];
int8x16_t q8_h_23[2];
for (int i = 0; i < 2; i++) {
const int offset = i * 32;
q8_h_01[i] = vld1q_s8(q8_base_h + offset);
q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16);
}
const int ql_off_base = sb * QK_K / 2;
uint8x16_t q6_ql_0[4];
uint8x16_t q6_ql_1[4];
for (int k = 0; k < 4; k++) {
q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);
q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);
}
const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes
uint8x16_t q6_qh_0[4];
uint8x16_t q6_qh_1[4];
for (int k = 0; k < 4; k++) {
q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);
q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);
}
// Adjust for the proper high bits (Sb 2 and 3)
if (sb > 1) {
for (int k = 0; k < 4; k++) {
q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);
q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);
}
}
// Process column pairs (0-1, 2-3, 4-5, 6-7)
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];
const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];
const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];
const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];
// Extract high 2 bits for upper nibble reconstruction
const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
// q6 = (low4 | high2<<4) - 32
// Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K)
const int8x16_t q6_l0 = vsubq_s8(
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),
m32s);
const int8x16_t q6_l1 = vsubq_s8(
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),
m32s);
const int8x16_t q6_h0 = vsubq_s8(
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);
const int8x16_t q6_h1 = vsubq_s8(
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);
// row pair 0, base_l
int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);
sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);
// row pair 0, base_h
int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);
sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);
// row pair 1, base_l
int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);
sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);
// row pair 1, base_h
int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);
sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);
const int scale_idx_l = half * 8 + sb;
const int scale_idx_h = half * 8 + sb + 4;
const int32x4_t scale_vec_l = {
q6_scales[scale_idx_l * 8 + cp * 2 + 0],
q6_scales[scale_idx_l * 8 + cp * 2 + 0],
q6_scales[scale_idx_l * 8 + cp * 2 + 1],
q6_scales[scale_idx_l * 8 + cp * 2 + 1],
};
const int32x4_t scale_vec_h = {
q6_scales[scale_idx_h * 8 + cp * 2 + 0],
q6_scales[scale_idx_h * 8 + cp * 2 + 0],
q6_scales[scale_idx_h * 8 + cp * 2 + 1],
q6_scales[scale_idx_h * 8 + cp * 2 + 1],
};
acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);
acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);
}
}
} // for half
// Reorder i8mm output to match memory layout
for (int i = 0; i < 8; i++) {
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
}
int32x4_t reorder_acc[8] = {
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
};
// Apply superblock scale (no mins for q6_K)
for (int i = 0; i < q8_k_blocklen; i++) {
for (int j = 0; j < 2; j++) {
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));
const float32x4_t scale = vmulq_f32(q6_d, q8_d);
acc_f32[2 * i + j] =
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
}
}
} // for b
// Store results
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q8_0_4x4_q8_0(int n,
float * GGML_RESTRICT s,
size_t bs,

View File

@ -6,6 +6,9 @@
#include "ggml-impl.h"
#include "simd-mappings.h"
#define GGML_FA_TILE_Q 32
#define GGML_FA_TILE_KV 16
#ifdef __cplusplus
#include <utility>
@ -84,4 +87,9 @@ static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_pa
return {ir0, ir1};
}
struct ggml_fa_tile_config {
static constexpr size_t Q = GGML_FA_TILE_Q;
static constexpr size_t KV = GGML_FA_TILE_KV;
};
#endif

View File

@ -14,6 +14,7 @@
#include "vec.h"
#include "ops.h"
#include "ggml.h"
#include "common.h"
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
@ -2866,10 +2867,12 @@ struct ggml_cplan ggml_graph_plan(
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
const int64_t ne10 = node->src[1]->ne[0]; // DK
const int64_t ne20 = node->src[2]->ne[0]; // DV
const int64_t DK = node->src[1]->ne[0];
const int64_t DV = node->src[2]->ne[0];
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
// Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
} break;
case GGML_OP_FLASH_ATTN_BACK:
{

View File

@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX {
} \
} \
template<typename T>
struct mma_instr;
template<>
struct mma_instr<ggml_bf16_t> {
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
__builtin_mma_xvbf16ger2pp(acc, a, b);
}
};
template<>
struct mma_instr<ggml_fp16_t> {
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
__builtin_mma_xvf16ger2pp(acc, a, b);
}
};
template <typename TA, typename TB, typename TC>
class tinyBLAS_BF16_PPC {
class tinyBLAS_HP16_PPC {
public:
tinyBLAS_BF16_PPC(int64_t k,
tinyBLAS_HP16_PPC(int64_t k,
const TA *A, int64_t lda,
const TB *B, int64_t ldb,
TC *C, int64_t ldc,
@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC {
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
}
}
SAVE_ACC(&acc_0, ii, jj);
@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC {
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
}
}
SAVE_ACC(&acc_0, ii, jj);
@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC {
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
for (int x = 0; x < 4; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
__builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
}
}
@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC {
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
for (int x = 0; x<2; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC {
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
for (int x = 0; x<4; x++) {
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
return tb.matmul(m, n);
}
#elif defined(__MMA__)
if ((k % 8))
return false;
if(Btype == GGML_TYPE_BF16) {
tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
(const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc,
params->ith, params->nth};
tb.matmul(m, n);
return true;
if (k % 8) {
return false;
}
if (Btype == GGML_TYPE_BF16) {
tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
(const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc,
params->ith, params->nth };
tb.matmul(m, n);
return true;
}
#elif defined(__riscv_zvfbfwma)
#if LMUL == 1
@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
#endif
return tb.matmul(m, n);
}
#elif defined(__MMA__)
if (k % 8) {
return false;
}
if (Btype == GGML_TYPE_F16) {
tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
(const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb,
(float *)C, ldc,
params->ith, params->nth };
tb.matmul(m, n);
return true;
}
#endif
return false;
}

View File

@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
}
}
static void ggml_compute_forward_flash_attn_ext_tiled(
const ggml_compute_params * params,
ggml_tensor * dst,
int ir0, int ir1) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;
GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
GGML_ASSERT(neq0 == DK);
GGML_ASSERT(nek0 == DK);
GGML_ASSERT(nev0 == DV);
GGML_ASSERT(neq1 == N);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(k->type == v->type);
const ggml_type kv_type = k->type;
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
const size_t kv_type_size = ggml_type_size(kv_type);
// broadcast factors
const int64_t rk2 = neq2/nek2;
const int64_t rk3 = neq3/nek3;
const int64_t rv2 = neq2/nev2;
const int64_t rv3 = neq3/nev3;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
int ith = params->ith;
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
int ir = ir0;
while (ir < ir1) {
// q indices for the start of this tile
const int iq3 = ir/(neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
// Number of valid rows in this tile:
// - limited by tile size (Q_TILE_SZ)
// - limited by chunk boundary (ir1 - ir)
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
GGML_ASSERT(tile_rows > 0);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
float S[Q_TILE_SZ];
float M[Q_TILE_SZ];
for (int i = 0 ; i < Q_TILE_SZ; ++i) {
S[i] = 0.;
M[i] = -INFINITY;
}
// Per-thread scratch layout:
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
void * Q_q = base;
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
// k indices
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;
// v indices
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;
for (int tq = 0; tq < tile_rows; tq++) {
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
}
// Zero-pad remaining rows
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
}
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
// skip the tile entirely if all the masks are -inf
if (mask) {
bool can_skip = true;
for (int tq = 0; tq < tile_rows; tq++) {
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
can_skip = false;
}
}
}
if (can_skip) {
continue;
}
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
float s;
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
KQ[tq * KV_TILE_SZ + tk] = s * scale;
}
}
if (logit_softcap != 0.0f) {
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
}
if (mask) {
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
}
bool skip[Q_TILE_SZ] = {};
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
float * kq_row = KQ + tq * KV_TILE_SZ;
float tile_max;
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
if (tile_max == -INFINITY) {
skip[tq] = true;
continue;
}
const float Mold = M[tq];
const float Mnew = fmaxf(Mold, tile_max);
if (Mnew > Mold) {
const float ms = expf(Mold - Mnew);
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
S[tq] *= ms;
}
M[tq] = Mnew;
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
}
// Convert V tile to F32 first (if F16), then do MAD
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
// TODO: on ARM, native f16 should be faster
if (kv_type == GGML_TYPE_F16) {
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
}
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
}
}
} else {
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
if (skip[tq]) continue;
float * vkq_row = VKQ32 + tq * DV;
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
const float p = KQ[tq * KV_TILE_SZ + tk];
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
}
}
}
}
// sinks (apply only to valid rows in the tile)
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
for (int tq = 0; tq < tile_rows; tq++) {
float ms = 1.0f;
float vs = 1.0f;
if (s > M[tq]) {
ms = expf(M[tq] - s);
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
} else {
vs = expf(s - M[tq]);
}
S[tq] = S[tq] * ms + vs;
}
}
for (int tq = 0; tq < tile_rows; tq++) {
// V /= S
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
// dst indices
const int i1 = iq1 + tq;
const int i2 = iq2;
const int i3 = iq3;
// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
}
ir += tile_rows;
}
}
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
@ -8343,6 +8618,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
// The number of elements in each chunk
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
@ -8350,7 +8634,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
if (use_tiled) {
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
} else {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
}
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}

View File

@ -703,6 +703,97 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
}
}
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0f;
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < 16; k++) {
// k = 0.. 7 weights 0-63 low, 64-127 high
// k = 8..15 weights 128-191 low, 192-255 high
const int base_l = (k / 8) * 128 + (k % 8) * 8;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
// qh_half: offset to the correct 32-byte half (0 or 32)
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
for (int j = 0; j < ncols_interleaved; j++) {
// Interleaved scales
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * 64 + j * 8 + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
// qh indexing with 8-byte interleaving (like q5_K)
const int qh_byte_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_byte_l / 8;
const int qh_pos_l = qh_byte_l % 8;
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_byte_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_byte_h / 8;
const int qh_pos_h = qh_byte_h % 8;
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t a_l = a_ptr[l].qs[base_l + i];
const int8_t a_h = a_ptr[l].qs[base_h + i];
sumi_l += q_l * a_l;
sumi_h += q_h * a_h;
}
sumf[j] +=
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j];
}
}
}
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@ -1133,15 +1224,7 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
assert (nr % 4 == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(s);
UNUSED(bs);
UNUSED(vx);
UNUSED(vy);
UNUSED(nr);
UNUSED(nc);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
float sumf[4][8];
float sum_minf[4][8];
@ -1402,6 +1485,111 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
}
}
void ggml_gemm_q6_K_8x8_q8_K_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
float sumf[4][8];
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0f;
}
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < 16; k++) {
// k = 0.. 7 weights 0-63 low, 64-127 high
// k = 8..15 weights 128-191 low, 192-255 high
const int base_l = (k / 8) * 128 + (k % 8) * 8;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
// qh_half: offset to the correct 32-byte half (0 or 32)
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
// Activation base indices for q8_Kx4 interleaved format
// Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32
const int q8_base = (k / 8) * 512 + (k % 8) * 32;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
// Interleaved scales
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * 64 + j * 8 + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_idx_l / 8;
const int qh_pos_l = qh_idx_l % 8;
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_idx_h / 8;
const int qh_pos_h = qh_idx_h % 8;
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i];
const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256];
sumi_l += q_l * q8_l;
sumi_h += q_h * q8_h;
}
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
}
}
}
}
}
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@ -1801,8 +1989,7 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
// Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure
// For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures
for(int i = 0; i < 128; i++){
for (int i = 0; i < 128; i++) {
// Index for selecting which q2k super block
int src1 = (i % 16) / 2;
// Index for selecting scale
@ -1902,6 +2089,52 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
return out;
}
static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) {
block_q6_Kx8 out;
constexpr int n_blocks = 8; // Kx8
for (int i = 0; i < n_blocks; i++) {
out.d[i] = in[i].d;
}
const int end_ls = QK_K * 4 / blck_size_interleave;
// Interleave Q6_K quants by taking 8 bytes at a time
for (int i = 0; i < end_ls; ++i) {
int src_id = i % n_blocks;
int src_offset = (i / n_blocks) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
uint64_t elem_ls;
memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t));
memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t));
}
// Interleave high bits using same 8-byte pattern as low bits
const int end_hs = end_ls / 2;
for (int i = 0; i < end_hs; ++i) {
int src_id = i % n_blocks;
int src_offset = (i / n_blocks) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
uint64_t elem_hs;
memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t));
memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t));
}
// The below logic is designed so as to unpack and rearrange scales in Q6_K
// The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants
// Q6_K structure has an 8-bit scale per 16 elements -> 16 scales
// scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7] (bl = block)
constexpr int n_scales = QK_K / 16;
for (int i = 0; i < n_blocks; i++) {
for (int j = 0; j < n_scales; j++) {
out.scales[j * n_blocks + i] = in[i].scales[j];
}
}
return out;
}
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
@ -1983,7 +2216,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
for (int b = 0; b < nrow; b += nrows_interleaved) {
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++ ) {
for (int i = 0; i < nrows_interleaved; i++) {
dst_tmp[i] = src[x + i * nblocks];
}
*dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
@ -2027,6 +2260,35 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
return 0;
}
static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
GGML_ASSERT(interleave_block == 8);
constexpr int nrows_interleaved = 8;
block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
const block_q6_K * src = (const block_q6_K *) data;
block_q6_K dst_tmp[8];
int nrow = ggml_nrows(t);
int nblocks = t->ne[0] / QK_K;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K));
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}
for (int b = 0; b < nrow; b += nrows_interleaved) {
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++) {
dst_tmp[i] = src[x + i * nblocks];
}
*dst++ = make_block_q6_Kx8(dst_tmp, interleave_block);
}
src += nrows_interleaved * nblocks;
}
return 0;
}
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 8);
@ -2249,6 +2511,10 @@ template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
}
@ -2286,7 +2552,14 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
template <>
void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n,
float * s,
size_t bs,
const void * vx,
const void * vy,
int nr,
int nc) {
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -2302,6 +2575,10 @@ template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
@ -2330,7 +2607,14 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
template <>
void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n,
float * s,
size_t bs,
const void * vx,
const void * vy,
int nr,
int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
@ -2350,6 +2634,10 @@ template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
@ -2714,20 +3002,19 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
for (int ir1 = 0; ir1 < nr1; ir1++) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
const int id = row_mapping.i1; // selected expert index
const int id = row_mapping.i1; // selected expert index
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i12 = row_mapping.i2; // row index in src1
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
src0_cur + src0_cur_start * nb01,
src1_col, 1, src0_cur_end - src0_cur_start);
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(
ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
}
}
#undef MMID_MATRIX_ROW
@ -2743,7 +3030,6 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
} // namespace ggml::cpu::repack
static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {
// instance for Q4
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
@ -2756,6 +3042,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
// instance for Q5_K
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
// instance for Q6_K
static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;
// instance for Q2
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
@ -2812,6 +3101,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q5_K_8x8_q8_K;
}
}
} else if (cur->type == GGML_TYPE_Q6_K) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 8 == 0) {
return &q6_K_8x8_q8_K;
}
}
} else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) {

View File

@ -65,6 +65,16 @@ struct block_q5_Kx8 {
static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
"wrong q5_K block size/padding");
struct block_q6_Kx8 {
ggml_half d[8];
int8_t scales[QK_K / 16 * 8];
uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2)
uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4)
};
static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8,
"wrong q6_K block size/padding");
struct block_q8_Kx4 {
float d[4]; // delta
int8_t qs[QK_K * 4]; // quants
@ -95,13 +105,14 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -111,6 +122,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -130,6 +142,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -139,6 +152,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

View File

@ -53,6 +53,7 @@
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
#define GGML_CUDA_CC_DGX_SPARK 1210
#define GGML_CUDA_CC_RUBIN 1300
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000

View File

@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
const int nbatch_fa) {
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
const int ne11, const int ne12, const int nbatch_fa) {
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
@ -641,11 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@ -654,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
return;
}
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
if (jt*ncols1 + j >= ne01) {
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
return;
}
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup:
float dst_val = 0.0f;
@ -681,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
@ -782,7 +789,7 @@ void launch_fattn(
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
@ -882,8 +889,10 @@ void launch_fattn(
}
}
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int gqa_ratio = Q->ne[2] / K->ne[2];
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@ -958,7 +967,7 @@ void launch_fattn(
blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@ -1012,7 +1021,7 @@ void launch_fattn(
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);

View File

@ -933,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float logit_softcap,
const uint3 ne01,
const int ne02,
const int gqa_ratio,
const int ne11,
const int stride_Q1,
const int stride_Q2,
@ -940,6 +941,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int stride_V,
const int stride_mask,
const int jt,
const int zt_gqa,
const int kb0_start,
const int kb0_stop) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@ -1022,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int j = jc / ncols2;
const int c = jc % ncols2;
if (jt*ncols1 + j < int(ne01.z)) {
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@ -1408,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int j_dst = jc_dst / ncols2;
const int c_dst = jc_dst % ncols2;
if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
continue;
}
@ -1447,7 +1449,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
#else
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
scale, slope, logit_softcap, ne01, ne02,
scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
jt, kb0_start, kb0_stop);
NO_DEVICE_CODE;
@ -1520,12 +1522,13 @@ static __global__ void flash_attn_ext_f16(
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
// kbc == k block continuous, current index in continuous ijk space.
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@ -1536,22 +1539,24 @@ static __global__ void flash_attn_ext_f16(
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
const int head0 = zt * ncols2;
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
if (KV_max) {
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@ -1561,12 +1566,12 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
}
kbc += iter_k;
@ -1580,22 +1585,24 @@ static __global__ void flash_attn_ext_f16(
return;
}
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
const int head0 = zt * ncols2;
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
if (KV_max) {
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@ -1605,7 +1612,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
@ -1739,3 +1746,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

View File

@ -18,9 +18,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
}
}
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
if constexpr (ncols2 <= 16) {
if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
return;
}
}
if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
template <int DKQ, int DV>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (use_gqa_opt && gqa_ratio % 8 == 0) {
// On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
if (cc == GGML_CUDA_CC_VOLTA) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio > 4) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
if (use_gqa_opt && gqa_ratio > 2) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
if (use_gqa_opt && gqa_ratio > 1) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
return;
}
@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
@ -121,8 +146,46 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 4 == 0);
if (gqa_ratio % 16 == 0) {
if (gqa_ratio == 20) { // GLM 4.7 Flash
if (cc >= GGML_CUDA_CC_DGX_SPARK) {
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
break;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
break;
}
if (cc >= GGML_CUDA_CC_BLACKWELL) {
if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
break;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
break;
}
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
if (Q->ne[1] <= 4) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
break;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
break;
}
if (cc >= GGML_CUDA_CC_TURING) {
if (Q->ne[1] <= 4) {
if (K->ne[1] <= 16384) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
break;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
break;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
break;
}
// Volta:
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
} else if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
@ -234,7 +297,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
// The effective batch size for the kernel can be increased by gqa_ratio.
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
for (const ggml_tensor * t : {Q, K, V, mask}) {
if (t == nullptr || ggml_is_quantized(t->type)) {
continue;
@ -247,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
}
const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
const int cc = ggml_cuda_info().devices[device].cc;
@ -268,7 +331,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
if (!V_is_K_view) {

View File

@ -3080,63 +3080,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
return true;
}
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
args.sigmoid = false;
args.softmax = false;
args.delayed_softmax = false;
args.prob_bias = false;
args.norm = false;
const int n_nodes = cgraph->n_nodes;
ggml_tensor ** nodes = cgraph->nodes;
if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
args.softmax = true;
}
if (nodes[node_idx]->op == GGML_OP_UNARY) {
if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
return false;
}
args.sigmoid = true;
}
if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
args.delayed_softmax = true;
}
node_idx++;
if (args.sigmoid || args.softmax) {
// SOFTMAX -> RESHAPE
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx];
node_idx++;
if (node_idx >= n_nodes) {
return false;
}
// src of bias add is the unreshaped probs (-2 instead of -1)
if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
args.prob_bias = true;
node_idx++;
}
// RESHAPE/ADD -> ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
return false;
}
if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
} else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
return false;
}
node_idx++;
// ARGSORT-> VIEW
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
return false;
}
// GET_ROWS
if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
} else if (args.delayed_softmax) {
if (node_idx - 2 < 0) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx - 2];
// VIEW->ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
// GET_ROWS
if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != probs_reshaped) {
return false;
}
node_idx++;
static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
for (const ggml_op op : remaining_ops) {
if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
}
}
// At this point we can check for norm + scale. Everything is now at least valid till the norm
if (node_idx >= n_nodes) {
return true;
}
if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
//check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
args.norm = true;
for (const ggml_op op : norm_ops) {
if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
node_idx++;
} else {
args.norm = false;
return true;
}
}
// DIV <- CLAMP, RESHAPE
if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
args.norm = false;
return true;
}
node_idx++;
if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
args.norm = false;
return true;
}
node_idx++;
}
if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
args.scale = true;
}
return true;
}
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<enum ggml_unary_op> unary_ops) {
#ifndef NDEBUG
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
GGML_ASSERT(unary_ops.size() == num_unary);
#endif
//TODO: remove special case once ggml_can_fuse can handle empty nodes
std::initializer_list<enum ggml_op> topk_moe_ops =
ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
};
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
@ -3398,35 +3501,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
// start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
ggml_cuda_topk_moe_args args;
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 9];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_tensor * clamp = cgraph->nodes[i + 7];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false, clamp);
i += 9;
continue;
}
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i + 4];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4;
continue;
}
std::vector<ggml_op> ops;
if (ggml_cuda_can_fuse(cgraph, i,
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i + 5];
ggml_tensor * ids = cgraph->nodes[i + 1];
if (can_fuse) {
const ggml_tensor * logits = node->src[0];
ggml_tensor * weights = nullptr;
ggml_tensor * ids = nullptr;
const ggml_tensor * bias = nullptr;
const ggml_tensor * clamp = nullptr;
const ggml_tensor * scale = nullptr;
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
/*delayed_softmax*/ true);
i += 5;
continue;
if (!args.delayed_softmax) {
ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
int out_nodes[2]; // nodes which can't be elided
if (args.prob_bias) {
bias = cgraph->nodes[i + 2]->src[1];
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS });
out_nodes[0] = i + 4;
ids = cgraph->nodes[i + 4];
} else {
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS });
out_nodes[0] = i + 3;
ids = cgraph->nodes[i + 3];
}
if (args.norm) {
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
GGML_OP_DIV, GGML_OP_RESHAPE });
clamp = cgraph->nodes[i + ops.size() - 3];
}
if (args.scale) {
ops.insert(ops.end(), { GGML_OP_SCALE });
scale = cgraph->nodes[i + ops.size() - 1];
}
weights = cgraph->nodes[i + ops.size() - 1];
out_nodes[1] = i + ops.size() - 1;
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
} else if (!args.norm && !args.prob_bias) {
//special case gpt-oss, no norm, no bias.
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
weights = cgraph->nodes[i + 5];
ids = cgraph->nodes[i + 1];
const ggml_tensor * softmax = cgraph->nodes[i + 4];
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
}
}
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
@ -4876,6 +5019,16 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
// Set CUDA_SCALE_LAUNCH_QUEUES before any CUDA API call to improve multi-GPU pipeline parallelism performance
// PR: https://github.com/ggml-org/llama.cpp/pull/19042
if (getenv("CUDA_SCALE_LAUNCH_QUEUES") == nullptr) {
#ifdef _WIN32
_putenv_s("CUDA_SCALE_LAUNCH_QUEUES", "4x");
#else
setenv("CUDA_SCALE_LAUNCH_QUEUES", "4x", 0); // don't overwrite if already set
#endif // _WIN32
}
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;

View File

@ -333,7 +333,33 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l;
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#elif defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
@ -391,7 +417,22 @@ namespace ggml_cuda_mma {
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
}
static __device__ __forceinline__ int get_i(const int l) {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
}
static __device__ __forceinline__ int get_j(const int l) {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
}
#elif defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
@ -945,6 +986,32 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
#ifdef AMD_MFMA_AVAILABLE
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3)
using floatx2_t = __attribute__((ext_vector_type(2))) float;
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
}
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(CDNA3)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,
@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma {
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // RDNA4
#elif defined(AMD_MFMA_AVAILABLE)
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma {
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // RDNA4
#endif // defined(RDNA4)
#elif defined(AMD_MFMA_AVAILABLE)
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3) || defined(CDNA2)
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
}
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
#endif // defined(CDNA3) || defined(CDNA2)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
}
template <data_layout dl_d, data_layout dl_ab>

View File

@ -2,6 +2,13 @@
#include "mmf.cuh"
#include "mmid.cuh"
static __forceinline__ int mmf_get_rows_per_block(const int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return MMF_ROWS_PER_BLOCK_CDNA;
} else {
return MMF_ROWS_PER_BLOCK;
}
}
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
ids_info_ptr = &ids_info;
}
const int device = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[device].cc;
const int rows_per_block = mmf_get_rows_per_block(cc);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<float>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<half2>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
mul_mat_f_switch_rows_per_block<nv_bfloat162>(
rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break;
@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
return false;
}
if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
return false;
}
@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
} else {
if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
return false;
} else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
//TODO: truse CDNA2 as CDNA1, tune the perf when CDNA2 is available.
return false;
} else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
return false;
} else if (src1_ncols > 16) {
return false;
}
@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
return ampere_mma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_F16:
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_BF16:
return ampere_mma_available(cc) || amd_wmma_available(cc);
return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
default:
return false;
}

View File

@ -7,6 +7,31 @@
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
#define MMF_ROWS_PER_BLOCK_CDNA 64
static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return 512;
} else {
return 256;
}
}
static __forceinline__ int mmf_get_padding(int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return 2;
} else {
return 4;
}
}
static constexpr __device__ int mmf_get_padding() {
#if defined(AMD_MFMA_AVAILABLE)
return 2;
#else
return 4;
#endif // defined(AMD_MFMA_AVAILABLE)
}
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
@ -29,23 +54,25 @@ static __global__ void mul_mat_f(
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
@ -57,7 +84,7 @@ static __global__ void mul_mat_f(
}
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
@ -198,7 +225,7 @@ static __global__ void mul_mat_f(
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) {
__syncthreads();
@ -228,27 +255,34 @@ static __global__ void mul_mat_f(
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
float sum[rows_per_block/warp_size] = {0.0f};
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
#pragma unroll
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
const int i = i0 + i1*warp_size + threadIdx.x;
sum += buf_iw[j*kiw + i];
sum[i1] += buf_iw[j*kiw + i];
}
}
if constexpr (!has_ids) {
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
} else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif //VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@ -256,7 +290,7 @@ static __global__ void mul_mat_f(
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
//This kernel is for larger batch sizes of mul_mat_id
@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids(
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, get_input_data_layout()> tile_A;
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids(
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) {
__syncthreads();
@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids(
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
float sum[rows_per_block/warp_size] = {0.0f};
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
#pragma unroll
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
const int i = i0 + i1*warp_size + threadIdx.x;
sum += buf_iw[j*kiw + i];
sum[i1] += buf_iw[j * kiw + i];
}
}
const int global_j = col_base + j;
@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids(
const int token = (int) qrm.x;
if (token < ncols_dst_total) {
const int slot = (int) qrm.y;
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
#pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
}
}
}
#ifdef VOLTA_MMA_AVAILABLE
}
#endif // VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
template<typename T, int cols_per_block, int nwarps>
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
}
template <typename T, int cols_per_block>
template <typename T, int rows_per_block, int cols_per_block>
void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@ -605,7 +645,7 @@ void mul_mat_f_cuda(
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
int64_t max_block_size = mmf_get_max_block_size(cc);
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
@ -614,10 +654,9 @@ void mul_mat_f_cuda(
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
@ -628,56 +667,56 @@ void mul_mat_f_cuda(
switch (nwarps_best) {
case 1: {
mul_mat_f_switch_ids<T, cols_per_block, 1>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
@ -691,7 +730,7 @@ void mul_mat_f_cuda(
GGML_UNUSED_VARS(nchannels_y);
}
template <typename T>
template <typename T, int rows_per_block>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(
switch (ncols_case) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
}
}
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
template void mul_mat_f_cuda<T, ncols_dst>( \
template <typename T>
static void mul_mat_f_switch_rows_per_block(
const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream, const mmf_ids_data * ids_data) {
switch (rows_per_block) {
case MMF_ROWS_PER_BLOCK: {
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case MMF_ROWS_PER_BLOCK_CDNA: {
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
default:
GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
}
}
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#if !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

View File

@ -71,7 +71,7 @@ for type_k in TYPES_KV:
f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))
for ncols in [8, 16, 32, 64]:
for ncols2 in [1, 2, 4, 8, 16]:
for ncols2 in [1, 2, 4, 8, 16, 32]:
if ncols2 > ncols:
continue
ncols1 = ncols // ncols2
@ -83,9 +83,9 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 72:
continue
if head_size_kq != 576 and ncols2 == 16:
if head_size_kq != 576 and ncols2 in (16, 32):
continue
if head_size_kq == 576 and ncols2 not in (4, 16):
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
continue
head_size_v = head_size_kq if head_size_kq != 576 else 512
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))

View File

@ -5,6 +5,13 @@
#include <cmath>
#include <initializer_list>
// Kernel config struct - passed by value to CUDA kernel
struct topk_moe_config {
bool use_sigmoid;
bool with_norm;
bool delayed_softmax;
};
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
}
}
template <int experts_per_thread, bool use_limit>
__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
}
}
/*
This kernel does the following:
1. optionally softmax over the logits per token [n_experts, n_tokens]
@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
const int n_rows,
const int n_expert_used,
const float clamp_val) {
template <int n_experts, bool has_bias>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
float * bias,
const int n_rows,
const int n_expert_used,
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) {
return;
@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float wt[experts_per_thread];
// Initialize all slots to -INFINITY
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
}
if constexpr (!delayed_softmax) {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
if (!config.delayed_softmax) {
if (config.use_sigmoid) {
sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
} else {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
}
}
// selection_wt is only needed when bias is present (selection uses wt + bias)
// when no bias, we use wt directly for both selection and weight values
float selection_wt[has_bias ? experts_per_thread : 1];
if constexpr (has_bias) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
selection_wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
selection_wt[i / WARP_SIZE] =
(n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
}
}
//at this point, each thread holds either a portion of the softmax distribution
@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float max_val = wt[0];
int max_expert = threadIdx.x;
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
if constexpr (has_bias) {
float max_val_s = selection_wt[0];
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
max_val = wt[i];
max_val_s = selection_wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
max_val = val;
max_val_s = val_s;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
selection_wt[max_expert / WARP_SIZE] = -INFINITY;
}
} else {
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
}
}
@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[k] = max_expert;
if constexpr (with_norm) {
if (config.with_norm) {
wt_sum += max_val;
}
}
}
if constexpr (with_norm) {
if (config.with_norm) {
wt_sum = warp_reduce_sum(wt_sum);
wt_sum = max(wt_sum, clamp_val);
const float inv_sum = 1.0f / wt_sum;
@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
}
}
if constexpr (delayed_softmax) {
if (config.delayed_softmax) {
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
}
@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) {
weights[idx] = output_weights[i];
weights[idx] = output_weights[i] * scale_val;
}
}
if (!with_norm) {
GGML_UNUSED(clamp_val);
}
}
template <bool with_norm, bool delayed_softmax = false>
template<bool has_bias>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
int32_t * ids,
float * bias,
const int n_rows,
const int n_expert,
const int n_expert_used,
const float clamp_val) {
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
"delayed softmax is not supported with weight normalization");
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) {
case 1:
topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 2:
topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 4:
topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 8:
topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 16:
topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 32:
topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 64:
topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 128:
topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 256:
topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 512:
topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
case 576:
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break;
default:
GGML_ASSERT(false && "fatal error");
@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
}
}
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax,
ggml_tensor * clamp) {
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data;
float * bias_d = bias ? (float *) bias->data : nullptr;
float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
const int n_expert_used = weights->ne[1];
const bool with_norm = clamp != nullptr;
float clamp_val = -INFINITY;
if (with_norm) {
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
if (clamp) {
clamp_val = ggml_get_op_params_f32(clamp, 0);
}
topk_moe_config config;
config.use_sigmoid = args.sigmoid;
config.with_norm = with_norm;
config.delayed_softmax = args.delayed_softmax;
if (bias) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
} else {
GGML_ASSERT(clamp == nullptr);
if (delayed_softmax) {
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
}
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert) {
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
const ggml_tensor * logits,
const ggml_tensor * ids) {
const int n_expert = ids->nb[1] / ids->nb[0];
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
return false;
}
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
if (gating_op->op == GGML_OP_SOFT_MAX) {
const ggml_tensor * softmax = gating_op;
float scale = 1.0f;
float max_bias = 0.0f;
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;
}
if (clamp) {
if (clamp->op != GGML_OP_CLAMP) {
if (!ggml_is_contiguous(softmax->src[0])) {
return false;
}
float max_val = ggml_get_op_params_f32(clamp, 1);
if (max_val != INFINITY) {
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
} else if (gating_op->op == GGML_OP_UNARY) {
ggml_unary_op op = ggml_get_unary_op(gating_op);
if (op != GGML_UNARY_OP_SIGMOID) {
return false;
}
}
return true;
}
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
GGML_ASSERT(!norm || !delayed_softmax);
if (delayed_softmax) {
return delayed_softmax_ops;
}
if (norm) {
return norm_ops;
}
return no_norm_ops;
}

View File

@ -3,19 +3,25 @@
#include <initializer_list>
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const bool with_norm,
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
struct ggml_cuda_topk_moe_args {
bool sigmoid{};
bool softmax{};
bool delayed_softmax{};
bool prob_bias{};
bool norm{};
bool scale{};
};
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
const ggml_tensor * logits,
const ggml_tensor * ids);

View File

@ -62,6 +62,8 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmf*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")

View File

@ -785,8 +785,12 @@ ggml_metal_device_t ggml_metal_device_init(void) {
dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
dev->props.max_buffer_size = dev->mtl_device.maxBufferLength;
dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;
if (@available(macOS 10.12, iOS 16.0, *)) {
dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
} else {
dev->props.max_working_set_size = dev->mtl_device.maxBufferLength;
}
strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1);

View File

@ -85,7 +85,8 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_0_f32_8x_flat
mul_mv_q4_0_f32_1d_8x_flat
mul_mv_q4_0_f32_1d_16x_flat
mul_mv_q6_k
mul_mv_q6_k_f32
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
mul_mv_q8_0_f32_flat
mul_mv_mxfp4_f32

View File

@ -533,8 +533,10 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
cl_kernel kernel_convert_block_q4_0_noshuffle;
cl_kernel kernel_restore_block_q4_0_noshuffle;
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32_flat;
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat;
cl_kernel kernel_solve_tri_f32;
@ -892,6 +894,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
GGML_LOG_CONT(".");
}
@ -1114,14 +1118,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_q6_k
// mul_mv_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q6_k.cl.h"
#include "mul_mv_q6_k_f32.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q6_k.cl");
const std::string kernel_src = read_file("mul_mv_q6_k_f32.cl");
#endif
backend_ctx->program_mul_mv_q6_K =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
@ -1130,6 +1134,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_q6_k_f32_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q6_k_f32_flat.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q6_k_f32_flat.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q6_K_f32_flat", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_q8_0_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -2919,6 +2940,50 @@ struct ggml_tensor_extra_cl_q8_0 {
}
};
struct ggml_tensor_extra_cl_q6_K {
// Lower 4 bits of quantized weights.
cl_mem ql = nullptr;
// Upper 2 bits of quantized weights.
cl_mem qh = nullptr;
// Scales for each block.
cl_mem s = nullptr;
// Scales for each super block.
cl_mem d = nullptr;
size_t size_ql = 0;
size_t size_qh = 0;
size_t size_s = 0;
size_t size_d = 0;
~ggml_tensor_extra_cl_q6_K() {
reset();
}
void reset() {
if (ql != nullptr) {
CL_CHECK(clReleaseMemObject(ql));
ql = nullptr;
}
if (qh != nullptr) {
CL_CHECK(clReleaseMemObject(qh));
qh = nullptr;
}
if (s != nullptr) {
CL_CHECK(clReleaseMemObject(s));
s = nullptr;
}
if (d != nullptr) {
CL_CHECK(clReleaseMemObject(d));
d = nullptr;
}
size_ql = 0;
size_qh = 0;
size_s = 0;
size_d = 0;
}
};
//------------------------------------------------------------------------------
// Backend API
//------------------------------------------------------------------------------
@ -3465,6 +3530,12 @@ struct ggml_backend_opencl_buffer_context {
for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
delete e;
}
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) {
delete e;
}
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
delete e;
}
}
ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
@ -3527,6 +3598,21 @@ struct ggml_backend_opencl_buffer_context {
return extra;
}
ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() {
ggml_tensor_extra_cl_q6_K * extra;
if (temp_tensor_extras_q6_K.empty()) {
extra = new ggml_tensor_extra_cl_q6_K();
} else {
extra = temp_tensor_extras_q6_K.back();
temp_tensor_extras_q6_K.pop_back();
}
temp_tensor_extras_q6_K_in_use.push_back(extra);
extra->reset();
return extra;
}
void reset() {
for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
temp_tensor_extras.push_back(e);
@ -3547,6 +3633,11 @@ struct ggml_backend_opencl_buffer_context {
temp_tensor_extras_q8_0.push_back(e);
}
temp_tensor_extras_q8_0_in_use.clear();
for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
temp_tensor_extras_q6_K.push_back(e);
}
temp_tensor_extras_q6_K_in_use.clear();
}
// Pools for extras. Available extras are in `temp_tensor_extras`. Extras
@ -3562,6 +3653,8 @@ struct ggml_backend_opencl_buffer_context {
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K;
std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K_in_use;
// The buffer_context is initially created by ggml_backend_buft_alloc_buffer
// before any tensor is initialized (at the beginning of alloc_tensor_range).
@ -4068,6 +4161,92 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
return;
}
if (tensor->type == GGML_TYPE_Q6_K) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
// Allocate the new extra and create aliases from the original.
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K();
size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4;
size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16;
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) &&
"Incorrect tensor size");
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
CL_CHECK(clEnqueueWriteBuffer(
queue, data_device, CL_TRUE, 0,
ggml_nbytes(tensor), data, 0, NULL, NULL));
cl_buffer_region region;
// Subbuffer for ql
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_ql;
extra->ql = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
auto previous_origin = region.origin;
// Subbuffer for qh
region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment);
region.size = size_qh;
extra->qh = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Subbuffer for scales
region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment);
region.size = size_s;
extra->s = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Create subbuffer for d.
region.origin = align_to(previous_origin + size_s, backend_ctx->alignment);
region.size = size_d;
extra->d = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Flatten the weights
cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
extra->size_ql = size_ql;
extra->size_qh = size_qh;
extra->size_s = size_s;
extra->size_d = size_d;
tensor->extra = extra;
return;
}
#endif // GGML_OPENCL_SOA_Q
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
@ -4277,6 +4456,34 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(
queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
if (tensor->type == GGML_TYPE_Q6_K) {
ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra;
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
@ -7765,6 +7972,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
#endif
const int ne00 = src0 ? src0->ne[0] : 0;
@ -8648,14 +8856,49 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 2;
ndst = 4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 2;
ndst = 4;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3));
#else
kernel = backend_ctx->kernel_mul_mv_q6_K_f32;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 2;
nth1 = 16;
nth0 = 16;
nth1 = 2;
ndst = 1;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 2;
nth1 = 64;
nth0 = 64;
nth1 = 2;
ndst = 1;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
@ -8675,6 +8918,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
break;
case GGML_TYPE_MXFP4: {
#ifdef GGML_OPENCL_SOA_Q
@ -8777,7 +9021,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
} else if (src0t == GGML_TYPE_Q5_K) {
GGML_ASSERT(false && "not implemented");
} else if (src0t == GGML_TYPE_Q6_K) {
size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);

View File

@ -46,6 +46,16 @@ struct block_q4_0
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
struct block_q6_K {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
half d; // super-block scale
};
//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
@ -263,3 +273,63 @@ kernel void kernel_restore_block_q8_0(
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q6_K
// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
// Each thread processes a super block.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q6_K(
global struct block_q6_K * src0,
global uchar * dst_ql,
global uchar * dst_qh,
global char * dst_s,
global half * dst_d
) {
global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0);
global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
global char * s = (global char *) dst_s + QK_K/16*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
*d = b->d;
for (int i = 0; i < QK_K/2; ++i) {
ql[i] = b->ql[i];
}
for (int i = 0; i < QK_K/4; ++i) {
qh[i] = b->qh[i];
}
for (int i = 0; i < QK_K/16; ++i) {
s[i] = b->scales[i];
}
}
// Restore block_q6_K from flattened arrays.
// Each thread processes a super block.
kernel void kernel_restore_block_q6_K(
global uchar * dst_ql,
global uchar * dst_qh,
global char * dst_s,
global half * dst_d,
global struct block_q6_K * dst
) {
global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0);
global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
global char * s = (global char *) dst_s + QK_K/16*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
b->d = *d;
for (int i = 0; i < QK_K/2; ++i) {
b->ql[i] = ql[i];
}
for (int i = 0; i < QK_K/4; ++i) {
b->qh[i] = qh[i];
}
for (int i = 0; i < QK_K/16; ++i) {
b->scales[i] = s[i];
}
}

View File

@ -0,0 +1,194 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// kernel_mul_mv_q6_K_f32_flat
//------------------------------------------------------------------------------
#define Q6_K_MASK1 0x03
#define Q6_K_MASK2 0x0C
#define Q6_K_MASK3 0x30
#define Q6_K_MASK4 0xC0
#define QK_K 256
inline float block_q_6_K_dot_y_flat(
global uchar * blk_ql,
global uchar * blk_qh,
global char * blk_scales,
global half * blk_d,
global float * yy,
int ib,
int ip,
int is,
int l0
) {
int y_offset = 128*ip + l0;
int q_offset_l = 64*ip + l0;
int q_offset_h = 32*ip + l0;
global uchar * q1 = blk_ql + ib*128 + q_offset_l;
global uchar * q2 = q1 + QK_K/8;
global uchar * qh = blk_qh + ib*64 + q_offset_h;
global char * sc = blk_scales + ib*16 + is;
global float * y = yy + ib * QK_K + y_offset;
float dall = blk_d[ib];
float sumf = 0;
float4 sums = {0.f, 0.f, 0.f, 0.f};
sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f);
sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f);
sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f);
sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f);
sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f);
sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f);
sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f);
sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f);
sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f);
sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f);
sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f);
sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f);
sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f);
sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f);
sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f);
sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f);
sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
return sumf;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 2
#define N_SIMDWIDTH 64
#endif
#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q6_K_f32_flat(
global uchar * src0_ql,
global uchar * src0_qh,
global char * src0_s,
global half * src0_d,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int i12 = im%ne12;
int i13 = im/ne12;
int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST;
ulong offset_src0 = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
ulong offset_src0_ql = offset_src0 * 128;
ulong offset_src0_qh = offset_src0 * 64;
ulong offset_src0_s = offset_src0 * 16;
ulong offset_src0_d = offset_src0;
global uchar * blk_ql = (global uchar *) src0_ql + offset_src0_ql;
global uchar * blk_qh = (global uchar *) src0_qh + offset_src0_qh;
global char * blk_scales = (global char *) src0_s + offset_src0_s;
global half * blk_d = (global half *) src0_d + offset_src0_d;
global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1;
int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
int ip = tid/8; // first or second half of (super) block (0 or 1)
int il = tid%8; // each half has 8 parts, one per scale
int n = 4; // 4 scales at a time (and 4 sums)
int l0 = n*il; // offset into half-block, 0..28
int is = 8*ip + l0/16; // 0, 1, 8, 9
float4 sumf = 0;
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
if (first_row + 0 < ne01) {
sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0);
}
if (first_row + 1 < ne01) {
sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0);
}
if (first_row + 2 < ne01) {
sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0);
}
if (first_row + 3 < ne01) {
sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0);
}
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0),
sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2),
sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}

View File

@ -15,7 +15,6 @@
#include <sycl/sycl.hpp>
#include <sycl/half_type.hpp>
#include <syclcompat/math.hpp>
#include <map>
#ifdef GGML_SYCL_USE_INTEL_ONEMKL

View File

@ -4606,14 +4606,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
#endif
case GGML_OP_NORM:
return true;
case GGML_OP_L2_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
return true;
case GGML_OP_RMS_NORM_BACK:
return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
return true;
case GGML_OP_CONT:

View File

@ -251,7 +251,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
const float eps, queue_ptr stream, int device) {
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
@ -334,7 +333,6 @@ static void group_norm_f32_sycl(const float* x, float* dst,
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@ -374,7 +372,6 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
const int nrows, const float eps,
queue_ptr stream, int device) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);

View File

@ -0,0 +1,70 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
include(ExternalProject)
message(STATUS "Including the VirtGPU/Virglrenderer API Remoting")
# Download venus_hw.h from virglrenderer repository
ExternalProject_Add(
venus_hw_header
URL https://gitlab.freedesktop.org/virgl/virglrenderer/-/raw/virglrenderer-1.2.0/src/venus_hw.h
DOWNLOAD_NO_EXTRACT YES
DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include
DOWNLOAD_NAME venus_hw.h
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
LOG_DOWNLOAD ON
)
if (NOT GGML_VIRTGPU_BACKEND STREQUAL "ONLY")
message(STATUS "Enable the VirtGPU/Virglrenderer API Remoting frontend library")
find_package(PkgConfig REQUIRED)
pkg_check_modules(DRM REQUIRED libdrm)
if (NOT GGML_BACKEND_DL)
# cannot simply use USE_VIRTGPU, as in the 'else()' case the
# frontend isn't compiled
target_compile_definitions(ggml PUBLIC "GGML_USE_VIRTGPU_FRONTEND")
endif()
ggml_add_backend_library(ggml-virtgpu
ggml-backend-buffer.cpp
ggml-backend.cpp
ggml-backend-device.cpp
ggml-backend-reg.cpp
ggml-backend-buffer-type.cpp
virtgpu-apir.h
virtgpu-forward.gen.h
virtgpu.cpp
virtgpu-shm.cpp
virtgpu-utils.cpp
virtgpu-forward-device.cpp
virtgpu-forward-buffer-type.cpp
virtgpu-forward-buffer.cpp
virtgpu-forward-backend.cpp
virtgpu-forward-impl.h
apir_cs_ggml-rpc-front.cpp
../../include/ggml-virtgpu.h)
target_include_directories(ggml-virtgpu PUBLIC /usr/include/libdrm/)
target_link_libraries(ggml-virtgpu PUBLIC ${DRM_LIBRARIES})
target_include_directories(ggml-virtgpu PUBLIC ${DRM_INCLUDE_DIRS})
target_compile_options(ggml-virtgpu PUBLIC ${DRM_CFLAGS_OTHER})
target_include_directories(ggml-virtgpu PUBLIC ./include)
target_include_directories(ggml-virtgpu PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
# Ensure venus_hw.h is downloaded before building ggml-virtgpu
add_dependencies(ggml-virtgpu venus_hw_header)
target_compile_options(ggml-virtgpu PRIVATE -std=c++20)
else()
message(STATUS "Not building the VirtGPU/Virglrenderer API Remoting frontend library")
endif()
if (NOT GGML_VIRTGPU_BACKEND STREQUAL "OFF")
add_subdirectory("backend")
endif()

View File

@ -0,0 +1,87 @@
#include "backend/shared/apir_cs_rpc.h"
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-remoting.h"
#include <cinttypes>
#include <unordered_map>
#include <unordered_set>
#include <vector>
apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor) {
apir_rpc_tensor result;
result.id = reinterpret_cast<uint64_t>(tensor);
result.type = tensor->type;
if (tensor->buffer) {
ggml_backend_buffer_t buffer = tensor->buffer;
result.buffer = BUFFER_TO_HOST_HANDLE(buffer);
} else {
result.buffer = 0;
}
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
result.ne[i] = tensor->ne[i];
result.nb[i] = tensor->nb[i];
}
result.op = tensor->op;
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
result.op_params[i] = tensor->op_params[i];
}
result.flags = tensor->flags;
for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
}
result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
result.view_offs = tensor->view_offs;
result.data = reinterpret_cast<uint64_t>(tensor->data);
if (tensor->data) {
if (!tensor->buffer) {
GGML_ABORT("tensor has data but not buffer");
}
// tensor->data is serialized as an offset to the buffer base address
result.data -= reinterpret_cast<uint64_t>(BUFFER_TO_GGML_CONTEXT(tensor->buffer)->base);
}
snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
return result;
}
void apir_add_tensor(ggml_tensor * tensor,
std::vector<apir_rpc_tensor> & tensors,
std::unordered_set<ggml_tensor *> & visited) {
if (tensor == nullptr) {
return;
}
if (visited.find(tensor) != visited.end()) {
return;
}
visited.insert(tensor);
for (int i = 0; i < GGML_MAX_SRC; i++) {
apir_add_tensor(tensor->src[i], tensors, visited);
}
apir_add_tensor(tensor->view_src, tensors, visited);
tensors.push_back(apir_serialize_tensor(tensor));
}
void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
uint32_t n_nodes = cgraph->n_nodes;
std::vector<apir_rpc_tensor> tensors;
std::unordered_set<ggml_tensor *> visited;
for (uint32_t i = 0; i < n_nodes; i++) {
apir_add_tensor(cgraph->nodes[i], tensors, visited);
}
// serialization format:
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(apir_rpc_tensor)) |
uint32_t n_tensors = tensors.size();
int output_size =
sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(apir_rpc_tensor);
output.resize(output_size, 0);
memcpy(output.data(), &n_nodes, sizeof(n_nodes));
for (uint32_t i = 0; i < n_nodes; i++) {
memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
}
uint32_t * out_ntensors = (uint32_t *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
*out_ntensors = n_tensors;
apir_rpc_tensor * out_tensors =
(apir_rpc_tensor *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(apir_rpc_tensor));
}

View File

@ -0,0 +1,21 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
message(STATUS "Enable the VirtGPU/Virglrenderer backend library")
ggml_add_backend_library(ggml-virtgpu-backend
backend.cpp
backend-dispatched.cpp
backend-dispatched-backend.cpp
backend-dispatched-device.cpp
backend-dispatched-buffer.cpp
backend-dispatched-buffer-type.cpp
shared/api_remoting.h
shared/apir_backend.h
shared/apir_cs.h
apir_cs_ggml-rpc-back.cpp)
target_compile_options(ggml-virtgpu-backend PRIVATE -std=c++20)
# Add include directory for ggml-backend-impl.h and other core headers
target_include_directories(ggml-virtgpu-backend PRIVATE ../..)

View File

@ -0,0 +1,115 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "shared/apir_cs_rpc.h"
#include <cinttypes>
#include <unordered_map>
#include <unordered_set>
#include <vector>
std::unordered_set<ggml_backend_buffer_t> backend_buffers;
void apir_track_backend_buffer(ggml_backend_buffer_t buffer) {
backend_buffers.insert(buffer);
}
bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) {
auto it = backend_buffers.find(buffer);
if (it == backend_buffers.end()) {
return false;
}
backend_buffers.erase(it);
return true;
}
std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers() {
return backend_buffers;
}
ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) {
ggml_tensor * result =
ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
result->nb[i] = tensor->nb[i];
}
result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) {
printf("WARNING: HOST BUFFER NOT FOUND | %p\n", (void *) result->buffer);
result->buffer = nullptr;
}
uint64_t tensor_data = tensor->data;
if (result->buffer) {
// require that the tensor data does not go beyond the buffer end
uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
// tensor->data is serialized as an offset to the buffer base address
tensor_data += buffer_start;
GGML_ASSERT(tensor_data + tensor_size >= tensor_data); // check for overflow
GGML_ASSERT(tensor_data >= buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size);
}
result->op = (ggml_op) tensor->op;
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
result->op_params[i] = tensor->op_params[i];
}
result->flags = tensor->flags;
result->data = reinterpret_cast<void *>(tensor_data);
ggml_set_name(result, tensor->name);
return result;
}
ggml_tensor * apir_create_node(uint64_t id,
ggml_context * ctx,
const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,
std::unordered_map<uint64_t, ggml_tensor *> & tensor_map) {
if (id == 0) {
return nullptr;
}
if (tensor_map.find(id) != tensor_map.end()) {
return tensor_map[id];
}
const apir_rpc_tensor * tensor = tensor_ptrs.at(id);
ggml_tensor * result = apir_deserialize_tensor(ctx, tensor);
if (result == nullptr) {
return nullptr;
}
tensor_map[id] = result;
for (int i = 0; i < GGML_MAX_SRC; i++) {
result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
}
result->view_src = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
result->view_offs = tensor->view_offs;
return result;
}
ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes,
uint32_t n_tensors,
const apir_rpc_tensor * tensors,
const uint64_t * nodes) {
size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
ggml_init_params params = {
/*.mem_size =*/buf_size,
/*.mem_buffer =*/NULL,
/*.no_alloc =*/true,
};
ggml_context * ctx = ggml_init(params);
ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
graph->n_nodes = n_nodes;
std::unordered_map<uint64_t, const apir_rpc_tensor *> tensor_ptrs;
for (uint32_t i = 0; i < n_tensors; i++) {
tensor_ptrs[tensors[i].id] = &tensors[i];
}
std::unordered_map<uint64_t, ggml_tensor *> tensor_map;
for (uint32_t i = 0; i < n_nodes; i++) {
int64_t id;
memcpy(&id, &nodes[i], sizeof(id));
graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map);
}
return graph;
}

View File

@ -0,0 +1,13 @@
#include "shared/apir_backend.h"
#define BUFFER_TO_HOST_HANDLE(name) ggml_buffer_to_apir_handle(name)
static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) {
// in the backend, the buffer handle is the buffer pointer
return (apir_buffer_host_handle_t) buffer;
}
static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) {
// in the backend, the buffer handle is the buffer pointer
return (apir_buffer_type_host_handle_t) buft;
}

View File

@ -0,0 +1,65 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include "shared/apir_backend.h"
#include <cstdint>
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(enc);
static bool async_backend_initialized = false;
static bool async_backend;
if (!async_backend_initialized) {
ggml_backend_dev_props props;
dev->iface.get_props(dev, &props);
async_backend = props.caps.async;
async_backend_initialized = true;
}
uint32_t shmem_res_id;
apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
if (!shmem_data) {
GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n");
apir_decoder_set_fatal(dec);
return 1;
}
size_t cgraph_size;
apir_decode_size_t(dec, &cgraph_size);
apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size);
ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size);
ggml_status status;
#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1
for (int idx = 0; idx < cgraph->n_nodes; idx++) {
ggml_tensor * op = ggml_graph_node(cgraph, idx);
if (dev->iface.supports_op(dev, op)) {
continue;
}
GGML_LOG_ERROR("Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op));
status = GGML_STATUS_ABORTED;
apir_encode_ggml_status(enc, &status);
return 0;
}
#endif
status = bck->iface.graph_compute(bck, cgraph);
if (async_backend) {
bck->iface.synchronize(bck);
}
apir_encode_ggml_status(enc, &status);
return 0;
}

View File

@ -0,0 +1,89 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include <cstdint>
uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_type_t buft;
buft = apir_decode_ggml_buffer_type(dec);
const char * string = buft->iface.get_name(buft);
const size_t string_size = strlen(string) + 1;
apir_encode_array_size(enc, string_size);
apir_encode_char_array(enc, string, string_size);
return 0;
}
uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_type_t buft;
buft = apir_decode_ggml_buffer_type(dec);
size_t value = buft->iface.get_alignment(buft);
apir_encode_size_t(enc, &value);
return 0;
}
uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_type_t buft;
buft = apir_decode_ggml_buffer_type(dec);
size_t value = buft->iface.get_max_size(buft);
apir_encode_size_t(enc, &value);
return 0;
}
uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_type_t buft;
buft = apir_decode_ggml_buffer_type(dec);
bool is_host = buft->iface.is_host(buft);
apir_encode_bool_t(enc, &is_host);
return 0;
}
uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_type_t buft;
buft = apir_decode_ggml_buffer_type(dec);
size_t size;
apir_decode_size_t(dec, &size);
ggml_backend_buffer_t buffer;
buffer = buft->iface.alloc_buffer(buft, size);
apir_encode_ggml_buffer(enc, buffer);
if (buffer) {
apir_track_backend_buffer(buffer);
}
return 0;
}
uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_type_t buft;
buft = apir_decode_ggml_buffer_type(dec);
const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
size_t value = buft->iface.get_alloc_size(buft, op);
apir_encode_size_t(enc, &value);
return 0;
}

View File

@ -0,0 +1,131 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include <cstdint>
uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer);
apir_encode_uintptr_t(enc, &base);
return 0;
}
uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(enc);
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
ggml_tensor * tensor;
// safe to remove the const qualifier here
tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
uint32_t shmem_res_id;
apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
size_t offset;
apir_decode_size_t(dec, &offset);
size_t size;
apir_decode_size_t(dec, &size);
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
if (!shmem_data) {
GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n");
return 1;
}
buffer->iface.set_tensor(buffer, tensor, shmem_data, offset, size);
return 0;
}
uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(enc);
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
const ggml_tensor * tensor;
// safe to remove the const qualifier here
tensor = apir_decode_ggml_tensor(dec);
uint32_t shmem_res_id;
apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
size_t offset;
apir_decode_size_t(dec, &offset);
size_t size;
apir_decode_size_t(dec, &size);
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
if (!shmem_data) {
GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n");
return 1;
}
buffer->iface.get_tensor(buffer, tensor, shmem_data, offset, size);
return 0;
}
uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
const ggml_tensor * src;
// safe to remove the const qualifier here
src = apir_decode_ggml_tensor(dec);
ggml_tensor * dst = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
bool ret = buffer->iface.cpy_tensor(buffer, src, (ggml_tensor *) dst);
apir_encode_bool_t(enc, &ret);
return 0;
}
uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(enc);
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
uint8_t value;
apir_decode_uint8_t(dec, &value);
buffer->iface.clear(buffer, value);
return 0;
}
uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(enc);
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!apir_untrack_backend_buffer(buffer)) {
GGML_LOG_WARN("%s: unknown buffer %p\n", __func__, (void *) buffer);
return 1;
}
buffer->iface.free_buffer(buffer);
return 0;
}

View File

@ -0,0 +1,148 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include <cstdint>
uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
int32_t dev_count = reg->iface.get_device_count(reg);
apir_encode_int32_t(enc, &dev_count);
return 0;
}
uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
int32_t dev_count = reg->iface.get_device_count(reg);
apir_encode_int32_t(enc, &dev_count);
return 0;
}
uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
const char * string = dev->iface.get_name(dev);
const size_t string_size = strlen(string) + 1;
apir_encode_array_size(enc, string_size);
apir_encode_char_array(enc, string, string_size);
return 0;
}
uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
const char * string = dev->iface.get_description(dev);
const size_t string_size = strlen(string) + 1;
apir_encode_array_size(enc, string_size);
apir_encode_char_array(enc, string, string_size);
return 0;
}
uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
uint32_t type = dev->iface.get_type(dev);
apir_encode_uint32_t(enc, &type);
return 0;
}
uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
size_t free, total;
dev->iface.get_memory(dev, &free, &total);
apir_encode_size_t(enc, &free);
apir_encode_size_t(enc, &total);
return 0;
}
uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
bool supports_op = dev->iface.supports_op(dev, op);
apir_encode_bool_t(enc, &supports_op);
return 0;
}
uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
ggml_backend_buffer_type_t bufft = dev->iface.get_buffer_type(dev);
apir_encode_ggml_buffer_type(enc, bufft);
return 0;
}
uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
ggml_backend_dev_props props;
dev->iface.get_props(dev, &props);
apir_encode_bool_t(enc, &props.caps.async);
apir_encode_bool_t(enc, &props.caps.host_buffer);
apir_encode_bool_t(enc, &props.caps.buffer_from_host_ptr);
apir_encode_bool_t(enc, &props.caps.events);
return 0;
}
uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(dec);
uint32_t shmem_res_id;
apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
if (!shmem_ptr) {
GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n");
apir_decoder_set_fatal(dec);
return 1;
}
size_t size;
apir_decode_size_t(dec, &size);
size_t max_tensor_size;
apir_decode_size_t(dec, &max_tensor_size);
ggml_backend_buffer_t buffer;
buffer = dev->iface.buffer_from_host_ptr(dev, shmem_ptr, size, max_tensor_size);
apir_encode_ggml_buffer(enc, buffer);
apir_encode_ggml_buffer_type(enc, buffer->buft);
if (buffer) {
apir_track_backend_buffer(buffer);
}
return 0;
}

View File

@ -0,0 +1,46 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include <cstdint>
ggml_backend_reg_t reg = NULL;
ggml_backend_dev_t dev = NULL;
ggml_backend_t bck = NULL;
uint64_t timer_start = 0;
uint64_t timer_total = 0;
uint64_t timer_count = 0;
uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) {
if (reg != NULL) {
GGML_LOG_WARN("%s: already initialized\n", __func__);
return APIR_BACKEND_INITIALIZE_ALREADY_INITED;
}
ggml_backend_reg_t (*ggml_backend_reg_fct)(void) = (ggml_backend_reg_t (*)()) ggml_backend_reg_fct_p;
reg = ggml_backend_reg_fct();
if (reg == NULL) {
GGML_LOG_ERROR("%s: backend registration failed\n", __func__);
return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED;
}
if (!reg->iface.get_device_count(reg)) {
GGML_LOG_ERROR("%s: backend initialization failed: no device found\n", __func__);
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
}
dev = reg->iface.get_device(reg, 0);
if (!dev) {
GGML_LOG_ERROR("%s: backend initialization failed: no device received\n", __func__);
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
}
bck = dev->iface.init_backend(dev, NULL);
return APIR_BACKEND_INITIALIZE_SUCCESS;
}

View File

@ -0,0 +1,130 @@
#pragma once
/* device */
uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
/* buffer-type */
uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
/* buffer */
uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
/* backend */
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) {
switch (type) {
/* device */
case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
return "backend_device_get_device_count";
case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
return "backend_device_get_count";
case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
return "backend_device_get_name";
case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
return "backend_device_get_description";
case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
return "backend_device_get_type";
case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
return "backend_device_get_memory";
case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
return "backend_device_supports_op";
case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
return "backend_device_get_buffer_type";
case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
return "backend_device_get_props";
case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
return "backend_device_buffer_from_ptr";
/* buffer-type */
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
return "backend_buffer_type_get_name";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
return "backend_buffer_type_get_alignment";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
return "backend_buffer_type_get_max_size";
case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
return "backend_buffer_type_is_host";
case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
return "backend_buffer_type_alloc_buffer";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
return "backend_buffer_type_get_alloc_size";
/* buffer */
case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
return "backend_buffer_get_base";
case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
return "backend_buffer_set_tensor";
case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
return "backend_buffer_get_tensor";
case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
return "backend_buffer_cpy_tensor";
case APIR_COMMAND_TYPE_BUFFER_CLEAR:
return "backend_buffer_clear";
case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
return "backend_buffer_free_buffer";
/* backend */
case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
return "backend_backend_graph_compute";
default:
return "unknown";
}
}
extern "C" {
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {
/* device */
/* APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = */ backend_device_get_device_count,
/* APIR_COMMAND_TYPE_DEVICE_GET_COUNT = */ backend_device_get_count,
/* APIR_COMMAND_TYPE_DEVICE_GET_NAME = */ backend_device_get_name,
/* APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = */ backend_device_get_description,
/* APIR_COMMAND_TYPE_DEVICE_GET_TYPE = */ backend_device_get_type,
/* APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = */ backend_device_get_memory,
/* APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = */ backend_device_supports_op,
/* APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = */ backend_device_get_buffer_type,
/* APIR_COMMAND_TYPE_DEVICE_GET_PROPS = */ backend_device_get_props,
/* APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = */ backend_device_buffer_from_ptr,
/* buffer-type */
/* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = */ backend_buffer_type_get_name,
/* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = */ backend_buffer_type_get_alignment,
/* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = */ backend_buffer_type_get_max_size,
/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = */ backend_buffer_type_is_host,
/* APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = */ backend_buffer_type_alloc_buffer,
/* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = */ backend_buffer_type_get_alloc_size,
/* buffer */
/* APIR_COMMAND_TYPE_BUFFER_GET_BASE = */ backend_buffer_get_base,
/* APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = */ backend_buffer_set_tensor,
/* APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = */ backend_buffer_get_tensor,
/* APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = */ backend_buffer_cpy_tensor,
/* APIR_COMMAND_TYPE_BUFFER_CLEAR = */ backend_buffer_clear,
/* APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = */ backend_buffer_free_buffer,
/* backend */
/* APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = */ backend_backend_graph_compute,
};
}

View File

@ -0,0 +1,23 @@
#pragma once
#include <cstdint>
#include <cstddef>
#include <ggml-backend.h>
#include "backend-convert.h"
#include "backend-virgl-apir.h"
#include "shared/apir_backend.h"
#include "shared/apir_cs.h"
#include "shared/apir_cs_ggml.h"
struct virgl_apir_context {
uint32_t ctx_id;
virgl_apir_callbacks * iface;
};
typedef uint32_t (*backend_dispatch_t)(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
#include "backend-dispatched.gen.h"
uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p);

View File

@ -0,0 +1,32 @@
#pragma once
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include "shared/api_remoting.h"
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
extern ggml_backend_reg_t reg;
extern ggml_backend_dev_t dev;
extern ggml_backend_t bck;
struct virgl_apir_callbacks {
const char * (*get_config)(uint32_t virgl_ctx_id, const char * key);
void * (*get_shmem_ptr)(uint32_t virgl_ctx_id, uint32_t res_id);
};
extern "C" {
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs);
void apir_backend_deinit(uint32_t virgl_ctx_id);
uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
virgl_apir_callbacks * virgl_cbs,
uint32_t cmd_type,
char * dec_cur,
const char * dec_end,
char * enc_cur,
const char * enc_end,
char ** enc_cur_after);
}

View File

@ -0,0 +1,148 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "shared/api_remoting.h"
#include "shared/apir_backend.h"
#include "shared/apir_cs.h"
#include <dlfcn.h>
#include <ggml-backend.h>
#include <iostream>
#define APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_PATH"
#define APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_REG"
#define APIR_LLAMA_CPP_LOG_TO_FILE_ENV "APIR_LLAMA_CPP_LOG_TO_FILE"
#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init"
static void * backend_library_handle = NULL;
static FILE * apir_logfile = NULL;
static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) {
FILE * logfile = (FILE *)user_data;
fprintf(logfile, "[%d] %s", level, text);
fflush(logfile);
}
extern "C" {
void apir_backend_deinit(uint32_t virgl_ctx_id) {
GGML_UNUSED(virgl_ctx_id);
auto buffers = apir_get_track_backend_buffers();
for (const auto & buffer : buffers) {
apir_untrack_backend_buffer(buffer);
buffer->iface.free_buffer(buffer);
}
if (dev) {
size_t free, total;
dev->iface.get_memory(dev, &free, &total);
GGML_LOG_INFO("%s: free memory: %ld MB\n", __func__, (size_t) free / 1024 / 1024);
}
if (backend_library_handle) {
GGML_LOG_INFO("%s: The GGML backend library was loaded. Unloading it.\n", __func__);
dlclose(backend_library_handle);
backend_library_handle = NULL;
}
if (apir_logfile) {
fclose(apir_logfile);
apir_logfile = NULL;
}
}
#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path"
#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) {
const char * dlsym_error;
const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV);
if (apir_log_to_file) {
apir_logfile = fopen(apir_log_to_file, "w");
if (apir_logfile) {
ggml_log_set(log_to_file_callback, apir_logfile);
} else {
GGML_LOG_INFO("Could not open the log file at '%s'\n", apir_log_to_file);
}
}
const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY);
const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
if (!library_name) {
GGML_LOG_ERROR("cannot open the GGML library: env var '%s' not defined\n", APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
}
backend_library_handle = dlopen(library_name, RTLD_LAZY);
if (!backend_library_handle) {
GGML_LOG_ERROR("cannot open the GGML library: %s\n", dlerror());
return APIR_LOAD_LIBRARY_CANNOT_OPEN;
}
if (!library_reg) {
GGML_LOG_ERROR("cannot register the GGML library: env var '%s' not defined\n", APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
}
void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg);
dlsym_error = dlerror();
if (dlsym_error) {
GGML_LOG_ERROR("cannot find the GGML backend registration symbol '%s' (from %s): %s\n", library_reg,
APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error);
return APIR_LOAD_LIBRARY_SYMBOL_MISSING;
}
uint32_t ret = backend_dispatch_initialize(ggml_backend_reg_fct);
return (ApirLoadLibraryReturnCode) (APIR_LOAD_LIBRARY_INIT_BASE_INDEX + ret);
}
uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
virgl_apir_callbacks * virgl_cbs,
uint32_t cmd_type,
char * dec_cur,
const char * dec_end,
char * enc_cur,
const char * enc_end,
char ** enc_cur_after) {
apir_encoder enc = {
.cur = enc_cur,
.start = enc_cur,
.end = enc_end,
.fatal = false,
};
apir_decoder dec = {
.cur = dec_cur,
.end = dec_end,
.fatal = false,
};
virgl_apir_context ctx = {
.ctx_id = virgl_ctx_id,
.iface = virgl_cbs,
};
if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) {
GGML_LOG_ERROR("Received an invalid dispatch index (%d >= %d)\n", cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT);
return APIR_BACKEND_FORWARD_INDEX_INVALID;
}
backend_dispatch_t forward_fct = apir_backend_dispatch_table[cmd_type];
uint32_t ret = forward_fct(&enc, &dec, &ctx);
*enc_cur_after = enc.cur;
return ret;
}
}

View File

@ -0,0 +1,90 @@
#pragma once
/* the rest of this file must match virglrenderer/src/apir-protocol.h */
#include <unistd.h>
#include <cstdint>
#define APIR_PROTOCOL_MAJOR 0
#define APIR_PROTOCOL_MINOR 1
#define APIR_HANDSHAKE_MAGIC 0xab1e
enum ApirCommandType {
APIR_COMMAND_TYPE_HANDSHAKE = 0,
APIR_COMMAND_TYPE_LOADLIBRARY = 1,
APIR_COMMAND_TYPE_FORWARD = 2,
APIR_COMMAND_TYPE_LENGTH = 3,
};
typedef uint64_t ApirCommandFlags;
enum ApirLoadLibraryReturnCode {
APIR_LOAD_LIBRARY_SUCCESS = 0,
APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1,
APIR_LOAD_LIBRARY_ALREADY_LOADED = 2,
APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3,
APIR_LOAD_LIBRARY_CANNOT_OPEN = 4,
APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5,
APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code
};
enum ApirForwardReturnCode {
APIR_FORWARD_SUCCESS = 0,
APIR_FORWARD_NO_DISPATCH_FCT = 1,
APIR_FORWARD_TIMEOUT = 2,
APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code
} ;
__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) {
switch (type) {
case APIR_COMMAND_TYPE_HANDSHAKE:
return "HandShake";
case APIR_COMMAND_TYPE_LOADLIBRARY:
return "LoadLibrary";
case APIR_COMMAND_TYPE_FORWARD:
return "Forward";
default:
return "unknown";
}
}
__attribute__((unused)) static const char * apir_load_library_error(ApirLoadLibraryReturnCode code) {
#define APIR_LOAD_LIBRARY_ERROR(code_name) \
do { \
if (code == code_name) \
return #code_name; \
} while (0)
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SUCCESS);
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR);
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ALREADY_LOADED);
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ENV_VAR_MISSING);
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_CANNOT_OPEN);
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SYMBOL_MISSING);
APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_INIT_BASE_INDEX);
return "Unknown APIR_COMMAND_TYPE_LoadLibrary error";
#undef APIR_LOAD_LIBRARY_ERROR
}
__attribute__((unused)) static const char * apir_forward_error(ApirForwardReturnCode code) {
#define APIR_FORWARD_ERROR(code_name) \
do { \
if (code == code_name) \
return #code_name; \
} while (0)
APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS);
APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT);
APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT);
APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX);
return "Unknown APIR_COMMAND_TYPE_FORWARD error";
#undef APIR_FORWARD_ERROR
}

View File

@ -0,0 +1,36 @@
typedef enum ApirBackendCommandType {
/* device */
APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = 0,
APIR_COMMAND_TYPE_DEVICE_GET_COUNT = 1,
APIR_COMMAND_TYPE_DEVICE_GET_NAME = 2,
APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = 3,
APIR_COMMAND_TYPE_DEVICE_GET_TYPE = 4,
APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = 5,
APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = 6,
APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = 7,
APIR_COMMAND_TYPE_DEVICE_GET_PROPS = 8,
APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = 9,
/* buffer-type */
APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = 10,
APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = 11,
APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = 12,
APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = 13,
APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = 14,
APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = 15,
/* buffer */
APIR_COMMAND_TYPE_BUFFER_GET_BASE = 16,
APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = 17,
APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = 18,
APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = 19,
APIR_COMMAND_TYPE_BUFFER_CLEAR = 20,
APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = 21,
/* backend */
APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = 22,
// last command_type index + 1
APIR_BACKEND_DISPATCH_TABLE_COUNT = 23,
} ApirBackendCommandType;

View File

@ -0,0 +1,46 @@
#pragma once
#include "apir_backend.gen.h"
#include <stdint.h> // for uintptr_t
#include <time.h> // for timespec, clock_gettime
#define APIR_BACKEND_INITIALIZE_SUCCESS 0
#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY 1
#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY 2
#define APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS 3
#define APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS 4
#define APIR_BACKEND_INITIALIZE_BACKEND_FAILED 5
#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6
#define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7
#define APIR_BACKEND_INITIALIZE_NO_DEVICE 8
// new entries here need to be added to the apir_backend_initialize_error function below
#define APIR_BACKEND_FORWARD_INDEX_INVALID 6
// 0 is fast, 1 avoids the backend to crash if an unsupported tensor is received
#define APIR_BACKEND_CHECK_SUPPORTS_OP 0
typedef uintptr_t apir_buffer_type_host_handle_t;
typedef uintptr_t apir_buffer_host_handle_t;
static const char * apir_backend_initialize_error(int code) {
#define APIR_BACKEND_INITIALIZE_ERROR(code_name) \
do { \
if (code == code_name) \
return #code_name; \
} while (0)
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_SUCCESS);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED);
return "Unknown APIR_BACKEND_INITIALIZE error:/";
#undef APIR_BACKEND_INITIALIZE_ERROR
}

View File

@ -0,0 +1,383 @@
#pragma once
#include "ggml-impl.h"
#include <cassert>
#include <cstring>
#define likely(x) __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)
struct apir_encoder {
char * cur;
const char * start;
const char * end;
bool fatal;
};
struct apir_decoder {
const char * cur;
const char * end;
bool fatal;
};
/*
* new encoder and decoder
*/
static apir_decoder apir_new_decoder(const char * ptr, size_t size) {
apir_decoder dec = {
.cur = ptr,
.end = ptr + size,
.fatal = false,
};
return dec;
}
static apir_encoder apir_new_encoder(char * ptr, size_t size) {
apir_encoder enc = {
.cur = ptr,
.start = ptr,
.end = ptr + size,
.fatal = false,
};
return enc;
}
/*
* fatal flag handling
*/
static inline void apir_encoder_reset_fatal(apir_encoder * enc) {
enc->fatal = false;
}
static inline void apir_encoder_set_fatal(apir_encoder * enc) {
enc->fatal = true;
}
static inline bool apir_encoder_get_fatal(const apir_encoder * enc) {
return enc->fatal;
}
static inline void apir_decoder_reset_fatal(apir_decoder * dec) {
dec->fatal = false;
}
static inline void apir_decoder_set_fatal(apir_decoder * dec) {
dec->fatal = true;
}
static inline bool apir_decoder_get_fatal(const apir_decoder * dec) {
return dec->fatal;
}
/*
* encode peek
*/
static inline bool apir_decoder_peek_internal(apir_decoder * dec,
size_t size,
void * val,
size_t val_size) {
assert(val_size <= size);
if (unlikely(size > (size_t) (dec->end - dec->cur))) {
GGML_LOG_ERROR("reading too much from the decoder ...\n");
apir_decoder_set_fatal(dec);
memset(val, 0, val_size);
return false;
}
/* we should not rely on the compiler to optimize away memcpy... */
memcpy(val, dec->cur, val_size);
return true;
}
static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) {
apir_decoder_peek_internal(dec, size, val, val_size);
}
static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) {
if (unlikely(size > (size_t) (dec->end - dec->cur))) {
GGML_LOG_ERROR("reading too much from the decoder ...\n");
apir_decoder_set_fatal(dec);
return NULL;
}
const void * addr = dec->cur;
dec->cur += size;
return addr;
}
/*
* read/write
*/
static inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) {
if (apir_decoder_peek_internal(dec, size, val, val_size)) {
dec->cur += size;
}
}
static inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) {
assert(val_size <= size);
assert(size <= ((size_t) (enc->end - enc->cur)));
char * write_addr = enc->cur;
/* we should not rely on the compiler to optimize away memcpy... */
memcpy(write_addr, val, val_size);
enc->cur += size;
return write_addr;
}
/*
* encode/decode
*/
static inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) {
assert(size % 4 == 0);
apir_decoder_read(dec, size, data, data_size);
}
static inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) {
assert(size % 4 == 0);
apir_encoder_write(enc, size, data, data_size);
}
/*
* typed encode/decode
*/
/* uint8_t */
static inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) {
apir_encode(enc, sizeof(int), val, sizeof(*val));
}
static inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) {
apir_decode(dec, sizeof(int), val, sizeof(*val));
}
/* uint64_t */
static inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) {
apir_encode(enc, 8, val, sizeof(*val));
}
static inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) {
apir_decode(dec, 8, val, sizeof(*val));
}
static inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) {
const size_t size = sizeof(*val) * count;
assert(size >= count);
apir_encode(enc, size, val, size);
}
static inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) {
const size_t size = sizeof(*val) * count;
assert(size >= count);
apir_decode(dec, size, val, size);
}
static inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) {
return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t));
}
/* int32_t */
static inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) {
apir_encode(enc, 4, val, sizeof(*val));
}
static inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) {
apir_decode(dec, 4, val, sizeof(*val));
}
static inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) {
const size_t size = sizeof(*val) * count;
assert(size >= count);
apir_encode(enc, size, val, size);
}
static inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) {
const size_t size = sizeof(*val) * count;
assert(size >= count);
apir_decode(dec, size, val, size);
}
/* array size (uint64_t) */
static inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) {
apir_encode_uint64_t(enc, &size);
}
static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) {
uint64_t size;
apir_decode_uint64_t(dec, &size);
if (size != expected_size) {
GGML_LOG_ERROR("Couldn't decode array from the decoder\n");
apir_decoder_set_fatal(dec);
size = 0;
}
return size;
}
static inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) {
uint64_t size;
apir_decode_uint64_t(dec, &size);
return size;
}
/* non-array pointer */
static inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) {
apir_encode_array_size(enc, val ? 1 : 0);
return val;
}
static inline bool apir_decode_simple_pointer(apir_decoder * dec) {
return apir_decode_array_size_unchecked(dec);
}
/* uint32_t */
static inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) {
apir_encode(enc, 4, val, sizeof(*val));
}
static inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) {
apir_decode(dec, 4, val, sizeof(*val));
}
static inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) {
const size_t size = sizeof(*val) * count;
assert(size >= count);
apir_encode(enc, size, val, size);
}
static inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) {
const size_t size = sizeof(*val) * count;
assert(size >= count);
apir_decode(dec, size, val, size);
}
/* size_t */
static inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) {
const uint64_t tmp = *val;
apir_encode_uint64_t(enc, &tmp);
}
static inline void apir_decode_size_t(apir_decoder * dec, size_t * val) {
uint64_t tmp;
apir_decode_uint64_t(dec, &tmp);
*val = tmp;
}
static inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) {
if (sizeof(size_t) == sizeof(uint64_t)) {
apir_encode_uint64_t_array(enc, (const uint64_t *) val, count);
} else {
for (uint32_t i = 0; i < count; i++) {
apir_encode_size_t(enc, &val[i]);
}
}
}
static inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) {
if (sizeof(size_t) == sizeof(uint64_t)) {
apir_decode_uint64_t_array(dec, (uint64_t *) val, count);
} else {
for (uint32_t i = 0; i < count; i++) {
apir_decode_size_t(dec, &val[i]);
}
}
}
/* opaque blob */
static inline void apir_encode_blob_array(apir_encoder * enc, const void * val, size_t size) {
apir_encode(enc, (size + 3) & ~3, val, size);
}
static inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) {
apir_decode(dec, (size + 3) & ~3, val, size);
}
/* string */
static inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) {
assert(size && strlen(val) < size);
apir_encode_blob_array(enc, val, size);
}
static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) {
apir_decode_blob_array(dec, val, size);
if (size) {
val[size - 1] = '\0';
} else {
GGML_LOG_ERROR("Couldn't decode the blog array\n");
apir_decoder_set_fatal(dec);
}
}
/* (temp) buffer allocation */
static inline void * apir_decoder_alloc_array(size_t size, size_t count) {
size_t alloc_size;
if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {
GGML_LOG_ERROR("overflow in array allocation of %zu * %zu bytes\n", size, count);
return NULL;
}
return malloc(alloc_size);
}
/* bool */
static inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) {
apir_encode(enc, sizeof(int), val, sizeof(bool));
}
static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {
apir_decode(dec, sizeof(int), val, sizeof(bool));
}
/* apir_buffer_type_host_handle_t */
static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
const apir_buffer_type_host_handle_t * val) {
apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
}
static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
apir_buffer_type_host_handle_t * val) {
apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
}
/* apir_buffer_host_handle_t */
static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc,
const apir_buffer_host_handle_t * val) {
apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
}
static inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) {
apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
}
/* uintptr_t */
static inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) {
apir_encode(enc, sizeof(*val), val, sizeof(*val));
}
static inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) {
apir_decode(dec, sizeof(*val), val, sizeof(*val));
}

View File

@ -0,0 +1,211 @@
#include "ggml-impl.h"
#include "apir_cs.h"
#include "apir_cs_rpc.h"
// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer);
static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc,
const apir_buffer_host_handle_t * handle);
static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec);
/* apir_rpc_tensor */
static inline void apir_encode_rcp_tensor(apir_encoder * enc, const apir_rpc_tensor * apir_rpc_tensor) {
size_t apir_rpc_tensor_size = sizeof(*apir_rpc_tensor);
apir_encode(enc, apir_rpc_tensor_size, apir_rpc_tensor, apir_rpc_tensor_size);
}
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder * dec) {
size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor);
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
}
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec,
uint32_t n_tensors) {
size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors;
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
}
/* ggml_tensor */
static inline void apir_encode_ggml_tensor(apir_encoder * enc, const ggml_tensor * tensor) {
apir_rpc_tensor serialized = apir_serialize_tensor(tensor);
apir_encode_rcp_tensor(enc, &serialized);
}
static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) {
const apir_rpc_tensor * apir_rpc_tensor = apir_decode_apir_rpc_tensor_inplace(dec);
ggml_init_params params{
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context * ctx = ggml_init(params);
const ggml_tensor * tensor = apir_deserialize_tensor(ctx, apir_rpc_tensor);
return tensor;
}
/* *** ggml_backend_buffer_type_t *** */
// ggml_backend_buffer_type_t is a POINTER (to a struct).
// Only the host pointer is shared between the host and guest.
// The guest stores it in `buft->context`.
// The host simply writes the pointer address in the buffer variable.
static inline void apir_encode_ggml_buffer_type(apir_encoder * enc, ggml_backend_buffer_type_t buft) {
apir_buffer_type_host_handle_t handle = ggml_buffer_type_to_apir_handle(buft);
apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));
}
static inline ggml_backend_buffer_type_t apir_decode_ggml_buffer_type(apir_decoder * dec) {
apir_buffer_type_host_handle_t handle;
apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle));
return (ggml_backend_buffer_type_t) handle;
}
static inline apir_buffer_type_host_handle_t apir_decode_apir_buffer_type_host_handle(apir_decoder * dec) {
apir_buffer_type_host_handle_t handle;
apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle));
return handle;
}
/* *** ggml_backend_type_t *** */
// ggml_backend_buffer_t is a POINTER.
// same logic as for ggml_backend_buffer_type_t
static inline void apir_encode_ggml_buffer(apir_encoder * enc, const ggml_backend_buffer_t buffer) {
apir_buffer_host_handle_t handle = BUFFER_TO_HOST_HANDLE(buffer);
apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));
}
static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) {
ggml_backend_buffer_t buffer;
size_t buffer_ptr_size = sizeof(buffer);
apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size);
return buffer;
}
/* enum ggml_status */
static inline void apir_encode_ggml_status(apir_encoder * enc, const ggml_status * status) {
apir_encoder_write(enc, sizeof(*status), status, sizeof(*status));
}
static inline void apir_decode_ggml_status(apir_decoder * dec, ggml_status * status) {
apir_decoder_read(dec, sizeof(*status), status, sizeof(*status));
}
/* virtgpu_shmem */
static inline void apir_encode_virtgpu_shmem_res_id(apir_encoder * enc, uint32_t shmem_res_id) {
apir_encode_uint32_t(enc, &shmem_res_id);
}
static inline void apir_decode_virtgpu_shmem_res_id(apir_decoder * dec, uint32_t * shmem_res_id) {
apir_decode_uint32_t(dec, shmem_res_id);
}
/* ggml_cgraph */
static inline size_t apir_serialize_ggml_cgraph(ggml_cgraph * cgraph, std::vector<uint8_t> & cgraph_data) {
apir_serialize_graph(cgraph, cgraph_data);
return cgraph_data.size();
}
static inline void apir_encode_cgraph_data(apir_encoder * enc, std::vector<uint8_t> & cgraph_data) {
size_t cgraph_size = cgraph_data.size();
apir_encode(enc, cgraph_size, cgraph_data.data(), cgraph_size);
}
static inline ggml_cgraph * apir_decode_ggml_cgraph(apir_decoder * dec, size_t cgraph_size) {
GGML_UNUSED(cgraph_size);
uint32_t n_nodes;
apir_decode_uint32_t(dec, &n_nodes);
const uint64_t * nodes = apir_decode_uint64_t_array_inplace(dec, n_nodes);
uint32_t n_tensors;
apir_decode_uint32_t(dec, &n_tensors);
const apir_rpc_tensor * tensors = apir_decode_apir_rpc_tensor_array_inplace(dec, n_tensors);
return apir_deserialize_graph(n_nodes, n_tensors, tensors, nodes);
}
static inline void apir_encode_ggml_buffer_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle) {
apir_encoder_write(enc, sizeof(*handle), &handle, sizeof(*handle));
}
static inline void apir_encode_ggml_tensor_inline(apir_encoder * enc, const ggml_tensor * tensor) {
size_t tensor_size = sizeof(*tensor);
if (tensor->extra) {
GGML_ABORT("Cannot pass tensors with extra");
}
if (tensor->src[0] && tensor->buffer) {
static int first = 1;
if (first) {
GGML_LOG_WARN("Cannot pass tensors with src and buffer\n");
first = 0;
}
}
apir_encoder_write(enc, tensor_size, tensor, tensor_size);
// tensor->data is a pointer inside the device buffer. No need to touch it
// tensor->buffer is a pointer to a buffer. Encoding the buffer handle in sequence.
// (could also make a copy of the tensor, and update locally.)
if (tensor->buffer) {
apir_buffer_host_handle_t buffer_handle = ggml_buffer_to_apir_handle(tensor->buffer);
apir_encode_ggml_buffer_handle(enc, &buffer_handle);
}
if (tensor->view_src) {
apir_encoder_write(enc, tensor_size, tensor->view_src, tensor_size);
}
for (int i = 0; tensor->src[i]; i++) {
const ggml_tensor * tensor_src = tensor->src[i];
apir_encoder_write(enc, tensor_size, tensor_src, tensor_size);
}
}
static inline const ggml_tensor * apir_decode_ggml_tensor_inplace(apir_decoder * dec) {
// it safe to remove the `const` qualifier here, we *do* want to
// modify the shared memory data to fix the `src` pointers.
ggml_tensor * tensor = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
// tensor->data is a pointer inside the device buffer. No need to touch it
// tensor->buffer is a pointer to a buffer. Decode the buffer handle encoded in sequence.
if (tensor->buffer) {
tensor->buffer = apir_decode_ggml_buffer(dec);
}
if (tensor->view_src) {
ggml_tensor * tensor_view_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
tensor->view_src = tensor_view_src;
}
for (int i = 0; tensor->src[i]; i++) {
ggml_tensor * tensor_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
tensor->src[i] = tensor_src; // overwrite op->src[i] pointer with the actual location of the src tensor
}
return tensor;
}

View File

@ -0,0 +1,54 @@
#include "ggml.h"
#include "ggml-backend-impl.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <cstdint>
// ggml_tensor is serialized into apir_rpc_tensor
struct apir_rpc_tensor {
uint64_t id;
uint32_t type;
uint64_t buffer;
uint32_t ne[GGML_MAX_DIMS];
uint32_t nb[GGML_MAX_DIMS];
uint32_t op;
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
int32_t flags;
uint64_t src[GGML_MAX_SRC];
uint64_t view_src;
uint64_t view_offs;
uint64_t data;
char name[GGML_MAX_NAME];
char padding[4];
};
/* frontend */
apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor);
void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output);
/* backend */
void apir_track_backend_buffer(ggml_backend_buffer_t buffer);
bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer);
std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers();
void apir_add_tensor(ggml_tensor * tensor,
std::vector<apir_rpc_tensor> & tensors,
std::unordered_set<ggml_tensor *> & visited);
ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor);
ggml_tensor * apir_create_node(uint64_t id,
ggml_context * ctx,
const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,
std::unordered_map<uint64_t, ggml_tensor *> & tensor_map);
ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes,
uint32_t n_tensors,
const apir_rpc_tensor * tensors,
const uint64_t * nodes);

View File

@ -0,0 +1,98 @@
#include "ggml-remoting.h"
static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
size_t size) {
virtgpu * gpu = BUFT_TO_GPU(buft);
ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context));
if (!context) {
GGML_ABORT("Couldn't allocate the buffer context ...");
}
context->gpu = gpu;
bool async__unused, host_buffer__unused, events__unused;
bool buffer_from_host_ptr;
apir_device_get_props(gpu, &async__unused, &host_buffer__unused, &buffer_from_host_ptr, &events__unused);
if (buffer_from_host_ptr) {
context->apir_context = apir_device_buffer_from_ptr(gpu, size, size);
context->base = context->apir_context.shmem.mmap_ptr;
context->is_from_ptr = true;
} else {
context->apir_context = apir_buffer_type_alloc_buffer(gpu, buft, size);
context->is_from_ptr = false;
context->base = NULL;
}
ggml_backend_buffer_t buffer =
ggml_backend_buffer_init(buft, ggml_backend_remoting_buffer_interface, (void *) context, size);
return buffer;
}
static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
virtgpu * gpu = BUFT_TO_GPU(buft);
return apir_buffer_type_get_name(gpu, buft);
}
static size_t ggml_backend_remoting_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
virtgpu * gpu = BUFT_TO_GPU(buft);
static size_t align = 0;
if (align == 0) {
align = apir_buffer_type_get_alignment(gpu, buft);
}
return align;
}
static size_t ggml_backend_remoting_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
virtgpu * gpu = BUFT_TO_GPU(buft);
static size_t max_size = 0;
if (max_size == 0) {
max_size = apir_buffer_type_get_max_size(gpu, buft);
}
return max_size;
}
static bool ggml_backend_remoting_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
virtgpu * gpu = BUFT_TO_GPU(buft);
return apir_buffer_type_is_host(gpu, buft);
}
static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
const ggml_tensor * tensor) {
virtgpu * gpu = BUFT_TO_GPU(buft);
if (tensor->buffer == NULL
|| !tensor->buffer->context
|| !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
return ggml_nbytes(tensor);
}
return apir_buffer_type_get_alloc_size(gpu, buft, tensor);
}
const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface = {
/* .get_name = */ ggml_backend_remoting_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_remoting_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_remoting_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_remoting_buffer_type_get_max_size,
/* .get_alloc_size = */ ggml_backend_remoting_buffer_type_get_alloc_size,
/* .is_host = */ NULL,
};
const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface = {
/* .get_name = */ ggml_backend_remoting_buffer_type_get_name,
/* .alloc_buffer = */ NULL,
/* .get_alignment = */ ggml_backend_remoting_buffer_type_get_alignment,
/* .get_max_size = */ ggml_backend_remoting_buffer_type_get_max_size,
/* .get_alloc_size = */ ggml_backend_remoting_buffer_type_get_alloc_size,
/* .is_host = */ NULL,
};

View File

@ -0,0 +1,119 @@
#include "ggml-remoting.h"
#define BUFFER_TO_GPU(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->gpu
static void * ggml_backend_remoting_buffer_get_base(ggml_backend_buffer_t buffer) {
ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) buffer->context;
if (context->base) {
return context->base;
}
context->base = apir_buffer_get_base(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer));
return context->base;
}
static void ggml_backend_remoting_buffer_set_tensor(ggml_backend_buffer_t buffer,
ggml_tensor * tensor,
const void * data,
size_t offset,
size_t size) {
virtgpu * gpu = BUFFER_TO_GPU(buffer);
ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer);
if (context->is_from_ptr) {
memcpy((char *) tensor->data + offset, data, size);
} else {
apir_buffer_set_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size);
}
return;
}
static void ggml_backend_remoting_buffer_get_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor * tensor,
void * data,
size_t offset,
size_t size) {
virtgpu * gpu = BUFFER_TO_GPU(buffer);
ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer);
if (context->is_from_ptr) {
memcpy(data, (const char *) tensor->data + offset, size);
} else {
apir_buffer_get_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size);
}
}
static void ggml_backend_remoting_buffer_set_tensor_from_ptr(ggml_backend_buffer_t buffer,
ggml_tensor * tensor,
const void * data,
size_t offset,
size_t size) {
UNUSED(buffer);
memcpy((char *) tensor->data + offset, data, size);
return;
}
static void ggml_backend_remoting_buffer_get_tensor_from_ptr(ggml_backend_buffer_t buffer,
const ggml_tensor * tensor,
void * data,
size_t offset,
size_t size) {
UNUSED(buffer);
memcpy(data, (const char *) tensor->data + offset, size);
}
static bool ggml_backend_remoting_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor * src,
ggml_tensor * dst) {
virtgpu * gpu = BUFFER_TO_GPU(buffer);
bool ret = apir_buffer_cpy_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), src, dst);
return ret;
}
static void ggml_backend_remoting_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
virtgpu * gpu = BUFFER_TO_GPU(buffer);
apir_buffer_clear(gpu, BUFFER_TO_APIR_CONTEXT(buffer), value);
return;
}
static void ggml_backend_remoting_buffer_free_buffer(ggml_backend_buffer_t buffer) {
virtgpu * gpu = BUFFER_TO_GPU(buffer);
apir_buffer_free_buffer(gpu, BUFFER_TO_APIR_CONTEXT(buffer));
ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer);
free(context);
buffer->context = NULL;
}
const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = {
/* .free_buffer = */ ggml_backend_remoting_buffer_free_buffer,
/* .get_base = */ ggml_backend_remoting_buffer_get_base,
/* .init_tensor = */ NULL,
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor,
/* .clear = */ ggml_backend_remoting_buffer_clear,
/* .reset = */ NULL,
};
const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = {
/* .free_buffer = */ ggml_backend_remoting_buffer_free_buffer,
/* .get_base = */ ggml_backend_remoting_buffer_get_base,
/* .init_tensor = */ NULL,
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr,
/* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr,
/* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor,
/* .clear = */ ggml_backend_remoting_buffer_clear,
/* .reset = */ NULL,
};

View File

@ -0,0 +1,144 @@
#include "ggml-remoting.h"
static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
return apir_device_get_name(gpu);
}
static const char * ggml_backend_remoting_device_get_description(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
return apir_device_get_description(gpu);
}
static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
static enum ggml_backend_dev_type type;
static bool has_type = false;
if (!has_type) {
has_type = true;
type = (enum ggml_backend_dev_type) apir_device_get_type(gpu);
}
return type;
}
static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
virtgpu * gpu = DEV_TO_GPU(dev);
return apir_device_get_memory(gpu, free, total);
}
static bool ggml_backend_remoting_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
#if USE_ALWAYS_TRUE_SUPPORTS_OP == 1
/* ggml-rpc cheats it like this */
/* with the current implementation of serialize_tensor, the src/view aren't properly passed */
UNUSED(dev);
UNUSED(op);
return true;
#else
virtgpu * gpu = DEV_TO_GPU(dev);
return apir_device_supports_op(gpu, op);
#endif
}
static bool ggml_backend_remoting_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
bool supported = buft->device == dev;
return supported;
}
static bool ggml_backend_remoting_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
UNUSED(dev);
UNUSED(op);
return false;
}
static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_remoting_device_get_name(dev);
props->description = ggml_backend_remoting_device_get_description(dev);
props->type = ggml_backend_remoting_device_get_type(dev);
ggml_backend_remoting_device_get_memory(dev, &props->memory_free, &props->memory_total);
virtgpu * gpu = DEV_TO_GPU(dev);
apir_device_get_props(gpu, &props->caps.async, &props->caps.host_buffer, &props->caps.buffer_from_host_ptr,
&props->caps.events);
props->caps.buffer_from_host_ptr = false;
props->caps.async = false;
props->caps.events = false;
}
ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
apir_buffer_type_host_handle_t ctx = apir_device_get_buffer_type(gpu);
static ggml_backend_buffer_type buft{
/* .iface = */ ggml_backend_remoting_buffer_type_interface,
/* .device = */ dev,
/* .context = */ (void *) ctx,
};
return &buft;
}
static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
apir_buffer_type_host_handle_t ctx = apir_device_get_buffer_type(gpu);
static ggml_backend_buffer_type buft{
/* .iface = */ ggml_backend_remoting_buffer_from_ptr_type_interface,
/* .device = */ dev,
/* .context = */ (void *) ctx,
};
return &buft;
}
static ggml_backend_buffer_t ggml_backend_remoting_device_buffer_from_ptr(ggml_backend_dev_t dev,
void * ptr,
size_t size,
size_t max_tensor_size) {
virtgpu * gpu = DEV_TO_GPU(dev);
ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context));
if (!context) {
GGML_ABORT("Couldn't allocate the buffer context ...");
}
context->gpu = gpu;
context->apir_context = apir_device_buffer_from_ptr(gpu, size, max_tensor_size);
context->base = ptr;
context->is_from_ptr = true;
ggml_backend_buffer_t buffer =
ggml_backend_buffer_init(ggml_backend_remoting_device_get_buffer_from_ptr_type(dev),
ggml_backend_remoting_buffer_from_ptr_interface, (void *) context, size);
return buffer;
}
const ggml_backend_device_i ggml_backend_remoting_device_interface = {
/* .get_name = */ ggml_backend_remoting_device_get_name,
/* .get_description = */ ggml_backend_remoting_device_get_description,
/* .get_memory = */ ggml_backend_remoting_device_get_memory,
/* .get_type = */ ggml_backend_remoting_device_get_type,
/* .get_props = */ ggml_backend_remoting_device_get_props,
/* .init_backend = */ ggml_backend_remoting_device_init,
/* .get_buffer_type = */ ggml_backend_remoting_device_get_buffer_type,
/* .get_host_buffer_type = */ NULL,
/* .buffer_from_host_ptr = */ ggml_backend_remoting_device_buffer_from_ptr,
/* .supports_op = */ ggml_backend_remoting_device_supports_op,
/* .supports_buft = */ ggml_backend_remoting_device_supports_buft,
/* .offload_op = */ ggml_backend_remoting_device_offload_op,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
};

View File

@ -0,0 +1,137 @@
#include "ggml-remoting.h"
#include "ggml-virtgpu.h"
#include <iostream>
#include <mutex>
static virtgpu * apir_initialize() {
static virtgpu * apir_gpu_instance = NULL;
static bool apir_initialized = false;
{
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (apir_initialized) {
return apir_gpu_instance;
}
apir_gpu_instance = create_virtgpu();
if (!apir_gpu_instance) {
GGML_ABORT("failed to initialize the virtgpu");
}
apir_initialized = true;
}
return apir_gpu_instance;
}
static int ggml_backend_remoting_get_device_count() {
virtgpu * gpu = apir_initialize();
if (!gpu) {
GGML_LOG_WARN("apir_initialize failed\n");
return 0;
}
return apir_device_get_count(gpu);
}
static size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) {
UNUSED(reg);
return ggml_backend_remoting_get_device_count();
}
static std::vector<ggml_backend_dev_t> devices;
ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) {
GGML_ASSERT(device < devices.size());
return devices[device];
}
static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) {
if (devices.size() > 0) {
GGML_LOG_INFO("%s: already initialized\n", __func__);
return;
}
virtgpu * gpu = apir_initialize();
if (!gpu) {
GGML_LOG_ERROR("apir_initialize failed\n");
return;
}
static bool initialized = false;
{
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) {
ggml_backend_remoting_device_context * ctx = new ggml_backend_remoting_device_context;
char desc[256] = "API Remoting device";
ctx->device = i;
ctx->name = GGML_REMOTING_FRONTEND_NAME + std::to_string(i);
ctx->description = desc;
ctx->gpu = gpu;
ggml_backend_dev_t dev = new ggml_backend_device{
/* .iface = */ ggml_backend_remoting_device_interface,
/* .reg = */ reg,
/* .context = */ ctx,
};
devices.push_back(dev);
}
initialized = true;
}
}
}
static ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_t reg, size_t device) {
UNUSED(reg);
return ggml_backend_remoting_get_device(device);
}
static const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) {
UNUSED(reg);
return GGML_REMOTING_FRONTEND_NAME;
}
static const ggml_backend_reg_i ggml_backend_remoting_reg_i = {
/* .get_name = */ ggml_backend_remoting_reg_get_name,
/* .get_device_count = */ ggml_backend_remoting_reg_get_device_count,
/* .get_device = */ ggml_backend_remoting_reg_get_device,
/* .get_proc_address = */ NULL,
};
ggml_backend_reg_t ggml_backend_virtgpu_reg() {
virtgpu * gpu = apir_initialize();
if (!gpu) {
GGML_LOG_ERROR("virtgpu_apir_initialize failed\n");
return NULL;
}
static ggml_backend_reg reg = {
/* .api_version = */ GGML_BACKEND_API_VERSION,
/* .iface = */ ggml_backend_remoting_reg_i,
/* .context = */ gpu,
};
static bool initialized = false;
if (initialized) {
return &reg;
}
initialized = true;
ggml_backend_remoting_reg_init_devices(&reg);
GGML_LOG_INFO("%s: initialized\n", __func__);
return &reg;
}
GGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg)

Some files were not shown because too many files have changed in this diff Show More