Merge with master

This commit is contained in:
Nikhil Jain 2026-02-23 19:54:05 -08:00
commit 9f9d06407e
No known key found for this signature in database
62 changed files with 5460 additions and 2508 deletions

View File

@ -1,8 +1,8 @@
ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
ARG ROCM_VERSION=7.0
ARG AMDGPU_VERSION=7.0
ARG ROCM_VERSION=7.2
ARG AMDGPU_VERSION=7.2
# Target the ROCm build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
@ -11,13 +11,12 @@ ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-co
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
# Unless otherwise specified, we make a fat build.
# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878
# This is mostly tied to rocBLAS supported archs.
# gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102,not officialy supported
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
#ARG ROCM_DOCKER_ARCH='gfx1151'
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
# Set ROCm architectures
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}

View File

@ -516,6 +516,102 @@ jobs:
path: llama-bin-win-sycl-x64.zip
name: llama-bin-win-sycl-x64.zip
ubuntu-22-rocm:
runs-on: ubuntu-22.04
strategy:
matrix:
include:
- ROCM_VERSION: "7.2"
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
build: 'x64'
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ubuntu-rocm-cmake-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}
evict-old-files: 1d
- name: Dependencies
id: depends
run: |
sudo apt install -y build-essential git cmake wget
- name: Setup Legacy ROCm
if: matrix.ROCM_VERSION == '7.2'
id: legacy_env
run: |
sudo mkdir --parents --mode=0755 /etc/apt/keyrings
wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | \
gpg --dearmor | sudo tee /etc/apt/keyrings/rocm.gpg > /dev/null
sudo tee /etc/apt/sources.list.d/rocm.list << EOF
deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${{ matrix.ROCM_VERSION }} jammy main
EOF
sudo tee /etc/apt/preferences.d/rocm-pin-600 << EOF
Package: *
Pin: release o=repo.radeon.com
Pin-Priority: 600
EOF
sudo apt update
sudo apt-get install -y libssl-dev rocm-hip-sdk
- name: Setup TheRock
if: matrix.ROCM_VERSION != '7.2'
id: therock_env
run: |
wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz
mkdir install
tar -xf *.tar.gz -C install
export ROCM_PATH=$(pwd)/install
echo ROCM_PATH=$ROCM_PATH >> $GITHUB_ENV
echo PATH=$PATH:$ROCM_PATH/bin >> $GITHUB_ENV
echo LD_LIBRARY_PATH=$ROCM_PATH/lib:$ROCM_PATH/llvm/lib:$ROCM_PATH/lib/rocprofiler-systems >> $GITHUB_ENV
- name: Build with native CMake HIP support
id: cmake_build
run: |
cmake -B build -S . \
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BACKEND_DL=ON \
-DGGML_NATIVE=OFF \
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DGGML_CPU_ALL_VARIANTS=ON \
-DGPU_TARGETS="${{ matrix.gpu_targets }}" \
-DGGML_HIP=ON \
-DHIP_PLATFORM=amd \
-DGGML_HIP_ROCWMMA_FATTN=ON \
${{ env.CMAKE_ARGS }}
cmake --build build --config Release -j $(nproc)
- name: Determine tag name
id: tag
uses: ./.github/actions/get-tag-name
- name: Pack artifacts
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- name: Upload artifacts
uses: actions/upload-artifact@v6
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
windows-hip:
runs-on: windows-2022
@ -784,6 +880,7 @@ jobs:
- windows-cuda
- windows-sycl
- windows-hip
- ubuntu-22-rocm
- ubuntu-22-cpu
- ubuntu-22-vulkan
- macOS-arm64
@ -868,6 +965,7 @@ jobs:
**Linux:**
- [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz)
- [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz)
- [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-7.2-x64.tar.gz)
- [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz)
**Windows:**

View File

@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
project("llama.cpp" C CXX)
include(CheckIncludeFileCXX)

View File

@ -803,7 +803,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
}
// remove potential partial suffix
if (builder.pos() == builder.input().size()) {
if (builder.pos() == builder.input().size() && builder.is_partial()) {
if (unclosed_reasoning_content.empty()) {
rstrip(content);
trim_potential_partial_word(content);

View File

@ -893,23 +893,6 @@ static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_call>";
form.tool_start = "<function=";
form.tool_sep = ">";
form.key_start = "<parameter=";
form.key_val_sep = ">";
form.val_end = "</parameter>";
form.tool_end = "</function>";
form.scope_end = "</tool_call>";
form.trim_raw_argval = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
@ -1590,9 +1573,6 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
common_chat_parse_qwen3_coder_xml(builder);
break;
case COMMON_CHAT_FORMAT_APRIEL_1_5:
common_chat_parse_apriel_1_5(builder);
break;

View File

@ -736,7 +736,6 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
@ -1522,14 +1521,17 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
return data;
}
static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
static common_chat_params common_chat_params_init_qwen3_coder(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
// Nemotron Nano 3 and Step-3.5-Flash use the Qwen3 Coder tool calling with thinking
bool supports_reasoning = (tmpl.source().find("<think>") != std::string::npos);
// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (supports_reasoning && string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
@ -1538,19 +1540,21 @@ static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_
}
data.preserved_tokens = {
"<think>",
"</think>",
"<tool_call>",
"</tool_call>",
};
if (supports_reasoning) {
data.preserved_tokens.insert(data.preserved_tokens.end(), {"<think>", "</think>"});
}
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true;
auto parser = build_chat_peg_constructed_parser([&](auto & p) {
auto reasoning = p.eps();
if (inputs.enable_thinking && extract_reasoning) {
if (supports_reasoning && inputs.enable_thinking && extract_reasoning) {
auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
if (data.thinking_forced_open) {
reasoning = reasoning_content;
@ -1888,38 +1892,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t
return data;
}
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.prompt = apply(tmpl, params);
data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;
data.preserved_tokens = {
"<tool_call>",
"</tool_call>",
"<function=",
"</function>",
"<parameter=",
"</parameter>",
};
// build grammar for tool call
static const xml_tool_call_format form {
/* form.scope_start = */ "<tool_call>\n",
/* form.tool_start = */ "<function=",
/* form.tool_sep = */ ">\n",
/* form.key_start = */ "<parameter=",
/* form.key_val_sep = */ ">\n",
/* form.val_end = */ "\n</parameter>\n",
/* form.tool_end = */ "</function>\n",
/* form.scope_end = */ "</tool_call>",
};
build_grammar_xml_tool_call(data, params.tools, form);
return data;
}
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
common_chat_params data;
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
@ -3147,13 +3119,7 @@ static common_chat_params common_chat_templates_apply_jinja(
src.find("<function=") != std::string::npos &&
src.find("<parameter=") != std::string::npos) {
workaround::func_args_not_string(params.messages);
// Models with <think> support (Step-3.5-Flash, Nemotron 3 Nano) use the
// Nemotron v3 PEG parser for streaming and schema-aware parameter parsing.
// Qwen3-Coder has no <think> in its template.
if (src.find("<think>") != std::string::npos) {
return common_chat_params_init_nemotron_v3(tmpl, params);
}
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
return common_chat_params_init_qwen3_coder(tmpl, params);
}
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)

View File

@ -128,7 +128,6 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_GLM_4_5,
COMMON_CHAT_FORMAT_MINIMAX_M2,
COMMON_CHAT_FORMAT_KIMI_K2,
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
COMMON_CHAT_FORMAT_APRIEL_1_5,
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
COMMON_CHAT_FORMAT_SOLAR_OPEN,

View File

@ -1760,3 +1760,65 @@ float lr_opt::get_lr(float epoch) const {
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
return r;
}
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
llama_batch batch = llama_batch_get_one(&last_token, 1);
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s: failed to replay last token\n", __func__);
return false;
}
return true;
}
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & tokens,
int & n_past,
int n_batch,
std::string_view state_path,
bool save_state) {
const int n_eval = tokens.size();
if (n_eval == 0) {
return true;
}
if (save_state && n_eval > 1) {
const int n_tokens_before_last = n_eval - 1;
GGML_ASSERT(n_eval <= n_batch);
// Decode all but the last token so we can save the memory state before decoding the last token.
// This is done so we can restore the session state later and replay the last token.
// Memory implementations in recurrent/hybrid models don't support removing tokens from their
// memory, so we can't just remove the last token from the memory and replay the last token which
// is the reason for this logic.
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_tokens_before_last;
llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
llama_token last_token = tokens.back();
llama_batch batch = llama_batch_get_one(&last_token, 1);
int32_t pos = n_past;
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval last token\n", __func__);
return false;
}
n_past++;
} else {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
return true;
}

View File

@ -804,6 +804,23 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
// decodes a single batch of tokens for a prompt and manages session tokens
//
// Note: We save state before the last token so that we can replay it to ensure
// compatibility with all memory types. Recurrent/hybrid models cannot remove
// tokens from memory, so this approach works across all model architectures.
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & embd,
int & n_past,
int n_batch,
std::string_view state_path,
bool save_state);
// replays the last token after loading state to regenerate logits
// used after loading session state to ensure the sampling context has valid logits
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
//
// Vocab utils
//

View File

@ -85,7 +85,7 @@ value identifier::execute_impl(context & ctx) {
auto builtins = global_builtins();
if (!it->is_undefined()) {
if (ctx.is_get_stats) {
it->stats.used = true;
value_t::stats_t::mark_used(it);
}
JJ_DEBUG("Identifier '%s' found, type = %s", val.c_str(), it->type().c_str());
return it;
@ -277,7 +277,7 @@ value binary_expression::execute_impl(context & ctx) {
static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) {
JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str());
if (ctx.is_get_stats) {
input->stats.used = true;
value_t::stats_t::mark_used(input);
input->stats.ops.insert(name);
}
auto builtins = input->get_builtins();
@ -448,7 +448,7 @@ value for_statement::execute_impl(context & ctx) {
// mark the variable being iterated as used for stats
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
value_t::stats_t::mark_used(iterable_val);
iterable_val->stats.ops.insert("array_access");
}
@ -470,7 +470,7 @@ value for_statement::execute_impl(context & ctx) {
items.push_back(std::move(tuple));
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
value_t::stats_t::mark_used(iterable_val);
iterable_val->stats.ops.insert("object_access");
}
} else {
@ -480,7 +480,7 @@ value for_statement::execute_impl(context & ctx) {
items.push_back(item);
}
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
value_t::stats_t::mark_used(iterable_val);
iterable_val->stats.ops.insert("array_access");
}
}
@ -817,8 +817,9 @@ value member_expression::execute_impl(context & ctx) {
}
if (ctx.is_get_stats && val && object && property) {
val->stats.used = true;
object->stats.used = true;
value_t::stats_t::mark_used(val);
value_t::stats_t::mark_used(object);
value_t::stats_t::mark_used(property);
if (is_val<value_int>(property)) {
object->stats.ops.insert("array_access");
} else if (is_val<value_string>(property)) {

View File

@ -161,6 +161,11 @@ static value tojson(const func_args & args) {
value val_separators = args.get_kwarg_or_pos("separators", 3);
value val_sort = args.get_kwarg_or_pos("sort_keys", 4);
int indent = -1;
if (args.ctx.is_get_stats) {
// mark as used (recursively) for stats
auto val_input = args.get_pos(0);
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
}
if (is_val<value_int>(val_indent)) {
indent = static_cast<int>(val_indent->as_int());
}
@ -891,6 +896,11 @@ const func_builtins & value_array_t::get_builtins() const {
}},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_array>();
if (args.ctx.is_get_stats) {
// mark as used (recursively) for stats
auto val_input = args.get_pos(0);
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
}
return mk_val<value_string>(args.get_pos(0)->as_string());
}},
{"tojson", tojson},
@ -1046,6 +1056,11 @@ const func_builtins & value_object_t::get_builtins() const {
{"tojson", tojson},
{"string", [](const func_args & args) -> value {
args.ensure_vals<value_object>();
if (args.ctx.is_get_stats) {
// mark as used (recursively) for stats
auto val_input = args.get_pos(0);
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
}
return mk_val<value_string>(args.get_pos(0)->as_string());
}},
{"length", [](const func_args & args) -> value {
@ -1358,4 +1373,21 @@ std::string value_to_string_repr(const value & val) {
}
}
// stats utility
void value_t::stats_t::mark_used(value & val, bool deep) {
val->stats.used = true;
if (deep) {
if (is_val<value_array>(val)) {
for (auto & item : val->val_arr) {
mark_used(item, deep);
}
} else if (is_val<value_object>(val)) {
for (auto & pair : val->val_obj) {
mark_used(pair.first, deep);
mark_used(pair.second, deep);
}
}
}
}
} // namespace jinja

View File

@ -118,6 +118,8 @@ struct value_t {
bool used = false;
// ops can be builtin calls or operators: "array_access", "object_access"
std::set<std::string> ops;
// utility to recursively mark value and its children as used
static void mark_used(value & val, bool deep = false);
} stats;
value_t() = default;

View File

@ -1274,6 +1274,9 @@ class TextModel(ModelBase):
if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
# ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
res = "joyai-llm"
if chkhsh == "e4d54df1ebc1f2b91acd986c5b51aa50837d5faf7c7398e73c1f9e9ee5d19869":
# ref: https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601
res = "kanana2"
if res is None:
logger.warning("\n")

View File

@ -152,6 +152,7 @@ models = [
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
]
# some models are known to be broken upstream, so we will skip them as exceptions

View File

@ -77,7 +77,10 @@ causal-verify-embeddings: causal-run-original-embeddings causal-run-converted-em
@./scripts/causal/compare-embeddings-logits.sh
causal-inspect-original-model:
@./scripts/utils/inspect-org-model.py
@./scripts/utils/inspect-org-model.py --list-all -s
causal-list-original-model-tensors:
@./scripts/utils/inspect-org-model.py --list-all-short -s
causal-inspect-converted-model:
@./scripts/utils/inspect-converted-model.sh
@ -153,7 +156,7 @@ embedding-verify-logits-st: embedding-run-original-model-st embedding-run-conver
embedding-inspect-original-model:
$(call validate_embedding_model_path,embedding-inspect-original-model)
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} --list-all -s
embedding-inspect-converted-model:
@CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/utils/inspect-converted-model.sh ${CONVERTED_EMBEDDING_MODEL}

View File

@ -1,67 +1,290 @@
#!/usr/bin/env python3
import argparse
import os
import json
import os
import re
import struct
import sys
from pathlib import Path
from typing import Optional
from safetensors import safe_open
from collections import defaultdict
parser = argparse.ArgumentParser(description='Process model with specified path')
parser.add_argument('--model-path', '-m', help='Path to the model')
args = parser.parse_args()
model_path = os.environ.get('MODEL_PATH', args.model_path)
if model_path is None:
parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
MODEL_SAFETENSORS_FILE = "model.safetensors"
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
# Check if there's an index file (multi-file model)
index_path = os.path.join(model_path, "model.safetensors.index.json")
single_file_path = os.path.join(model_path, "model.safetensors")
DTYPE_SIZES = {
"F64": 8, "I64": 8, "U64": 8,
"F32": 4, "I32": 4, "U32": 4,
"F16": 2, "BF16": 2, "I16": 2, "U16": 2,
"I8": 1, "U8": 1, "BOOL": 1,
"F8_E4M3": 1, "F8_E5M2": 1,
}
if os.path.exists(index_path):
# Multi-file model
print("Multi-file model detected")
SIZE_UNITS = ['B', 'KB', 'MB', 'GB', 'TB']
with open(index_path, 'r') as f:
index_data = json.load(f)
# Get the weight map (tensor_name -> file_name)
weight_map = index_data.get("weight_map", {})
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
index_file = model_path / MODEL_SAFETENSORS_INDEX
# Group tensors by file for efficient processing
file_tensors = defaultdict(list)
for tensor_name, file_name in weight_map.items():
file_tensors[file_name].append(tensor_name)
if index_file.exists():
with open(index_file, 'r') as f:
index = json.load(f)
return index.get("weight_map", {})
print("Tensors in model:")
return None
# Process each shard file
for file_name, tensor_names in file_tensors.items():
file_path = os.path.join(model_path, file_name)
print(f"\n--- From {file_name} ---")
with safe_open(file_path, framework="pt") as f:
for tensor_name in sorted(tensor_names):
tensor = f.get_tensor(tensor_name)
print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
def get_all_tensor_names(model_path: Path) -> list[str]:
weight_map = get_weight_map(model_path)
elif os.path.exists(single_file_path):
# Single file model (original behavior)
print("Single-file model detected")
if weight_map is not None:
return list(weight_map.keys())
with safe_open(single_file_path, framework="pt") as f:
keys = f.keys()
print("Tensors in model:")
for key in sorted(keys):
tensor = f.get_tensor(key)
print(f"- {key} : shape = {tensor.shape}, dtype = {tensor.dtype}")
single_file = model_path / MODEL_SAFETENSORS_FILE
if single_file.exists():
try:
with safe_open(single_file, framework="pt", device="cpu") as f:
return list(f.keys())
except Exception as e:
print(f"Error reading {single_file}: {e}")
sys.exit(1)
else:
print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {model_path}")
print("Available files:")
if os.path.exists(model_path):
for item in sorted(os.listdir(model_path)):
print(f" {item}")
print(f"Error: No safetensors files found in {model_path}")
sys.exit(1)
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
weight_map = get_weight_map(model_path)
if weight_map is not None:
return weight_map.get(tensor_name)
single_file = model_path / MODEL_SAFETENSORS_FILE
if single_file.exists():
return single_file.name
return None
def read_safetensors_header(file_path: Path) -> dict:
with open(file_path, 'rb') as f:
header_size = struct.unpack('<Q', f.read(8))[0]
return json.loads(f.read(header_size))
def get_tensor_size_bytes(tensor_meta: dict) -> int:
offsets = tensor_meta.get("data_offsets")
if offsets and len(offsets) == 2:
return offsets[1] - offsets[0]
n_elements = 1
for d in tensor_meta.get("shape", []):
n_elements *= d
return n_elements * DTYPE_SIZES.get(tensor_meta.get("dtype", "F32"), 4)
def format_size(size_bytes: int) -> str:
val = float(size_bytes)
for unit in SIZE_UNITS[:-1]:
if val < 1024.0:
return f"{val:.2f} {unit}"
val /= 1024.0
return f"{val:.2f} {SIZE_UNITS[-1]}"
def get_all_tensor_metadata(model_path: Path) -> dict[str, dict]:
weight_map = get_weight_map(model_path)
if weight_map is not None:
file_to_tensors: dict[str, list[str]] = {}
for tensor_name, file_name in weight_map.items():
file_to_tensors.setdefault(file_name, []).append(tensor_name)
all_metadata: dict[str, dict] = {}
for file_name, tensor_names in file_to_tensors.items():
try:
header = read_safetensors_header(model_path / file_name)
for tensor_name in tensor_names:
if tensor_name in header:
all_metadata[tensor_name] = header[tensor_name]
except Exception as e:
print(f"Warning: Could not read header from {file_name}: {e}", file=sys.stderr)
return all_metadata
single_file = model_path / MODEL_SAFETENSORS_FILE
if single_file.exists():
try:
header = read_safetensors_header(single_file)
return {k: v for k, v in header.items() if k != "__metadata__"}
except Exception as e:
print(f"Error reading {single_file}: {e}")
sys.exit(1)
print(f"Error: No safetensors files found in {model_path}")
sys.exit(1)
def normalize_tensor_name(tensor_name: str) -> str:
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
normalized = re.sub(r'\.\d+$', '.#', normalized)
return normalized
def list_all_tensors(
model_path: Path,
short: bool = False,
show_sizes: bool = False,
):
tensor_names = get_all_tensor_names(model_path)
metadata: Optional[dict[str, dict]] = None
if show_sizes:
metadata = get_all_tensor_metadata(model_path)
total_bytes = 0
if short:
seen: dict[str, str] = {}
for tensor_name in sorted(tensor_names):
normalized = normalize_tensor_name(tensor_name)
if normalized not in seen:
seen[normalized] = tensor_name
display_pairs = list(sorted(seen.items()))
name_width = max((len(n) for n, _ in display_pairs), default=0)
for normalized, first_name in display_pairs:
if metadata and first_name in metadata:
m = metadata[first_name]
size = get_tensor_size_bytes(m)
total_bytes += size
print(f"{normalized:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}")
else:
print(normalized)
else:
print(f" Directory {model_path} does not exist")
exit(1)
name_width = max((len(n) for n in tensor_names), default=0)
for tensor_name in sorted(tensor_names):
if metadata and tensor_name in metadata:
m = metadata[tensor_name]
size = get_tensor_size_bytes(m)
total_bytes += size
print(f"{tensor_name:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}")
else:
print(tensor_name)
if show_sizes:
print(f"\nTotal: {format_size(total_bytes)}")
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
tensor_file = find_tensor_file(model_path, tensor_name)
if tensor_file is None:
print(f"Error: Could not find tensor '{tensor_name}' in model index")
print(f"Model path: {model_path}")
sys.exit(1)
file_path = model_path / tensor_file
try:
header = read_safetensors_header(file_path)
tensor_meta = header.get(tensor_name, {})
dtype_str = tensor_meta.get("dtype")
with safe_open(file_path, framework="pt", device="cpu") as f:
if tensor_name in f.keys():
tensor_slice = f.get_slice(tensor_name)
shape = tensor_slice.get_shape()
print(f"Tensor: {tensor_name}")
print(f"File: {tensor_file}")
print(f"Shape: {shape}")
if dtype_str:
print(f"Dtype: {dtype_str}")
if tensor_meta:
print(f"Size: {format_size(get_tensor_size_bytes(tensor_meta))}")
if num_values is not None:
tensor = f.get_tensor(tensor_name)
if not dtype_str:
print(f"Dtype: {tensor.dtype}")
flat = tensor.flatten()
n = min(num_values, flat.numel())
print(f"Values: {flat[:n].tolist()}")
else:
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
sys.exit(1)
except FileNotFoundError:
print(f"Error: The file '{file_path}' was not found.")
sys.exit(1)
except Exception as e:
print(f"An error occurred: {e}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
description="Print tensor information from a safetensors model"
)
parser.add_argument(
"tensor_name",
nargs="?",
help="Name of the tensor to inspect"
)
parser.add_argument(
"-m", "--model-path",
type=Path,
help="Path to the model directory (default: MODEL_PATH environment variable)"
)
parser.add_argument(
"-l", "--list-all-short",
action="store_true",
help="List unique tensor patterns (layer numbers replaced with #)"
)
parser.add_argument(
"-la", "--list-all",
action="store_true",
help="List all tensor names with actual layer numbers"
)
parser.add_argument(
"-n", "--num-values",
nargs="?",
const=10,
default=None,
type=int,
metavar="N",
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
)
parser.add_argument(
"-s", "--sizes",
action="store_true",
help="Show dtype, shape, and size for each tensor when listing"
)
args = parser.parse_args()
model_path = args.model_path
if model_path is None:
model_path_str = os.environ.get("MODEL_PATH")
if model_path_str is None:
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
sys.exit(1)
model_path = Path(model_path_str)
if not model_path.exists():
print(f"Error: Model path does not exist: {model_path}")
sys.exit(1)
if not model_path.is_dir():
print(f"Error: Model path is not a directory: {model_path}")
sys.exit(1)
if args.list_all_short or args.list_all:
list_all_tensors(model_path, short=args.list_all_short, show_sizes=args.sizes)
else:
if args.tensor_name is None:
print("Error: tensor_name is required when not using --list-all-short or --list-all")
sys.exit(1)
print_tensor_info(model_path, args.tensor_name, args.num_values)
if __name__ == "__main__":
main()

View File

@ -1,174 +0,0 @@
#!/usr/bin/env python3
import argparse
import json
import os
import re
import sys
from pathlib import Path
from typing import Optional
from safetensors import safe_open
MODEL_SAFETENSORS_FILE = "model.safetensors"
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
index_file = model_path / MODEL_SAFETENSORS_INDEX
if index_file.exists():
with open(index_file, 'r') as f:
index = json.load(f)
return index.get("weight_map", {})
return None
def get_all_tensor_names(model_path: Path) -> list[str]:
weight_map = get_weight_map(model_path)
if weight_map is not None:
return list(weight_map.keys())
single_file = model_path / MODEL_SAFETENSORS_FILE
if single_file.exists():
try:
with safe_open(single_file, framework="pt", device="cpu") as f:
return list(f.keys())
except Exception as e:
print(f"Error reading {single_file}: {e}")
sys.exit(1)
print(f"Error: No safetensors files found in {model_path}")
sys.exit(1)
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
weight_map = get_weight_map(model_path)
if weight_map is not None:
return weight_map.get(tensor_name)
single_file = model_path / MODEL_SAFETENSORS_FILE
if single_file.exists():
return single_file.name
return None
def normalize_tensor_name(tensor_name: str) -> str:
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
normalized = re.sub(r'\.\d+$', '.#', normalized)
return normalized
def list_all_tensors(model_path: Path, unique: bool = False):
tensor_names = get_all_tensor_names(model_path)
if unique:
seen = set()
for tensor_name in sorted(tensor_names):
normalized = normalize_tensor_name(tensor_name)
if normalized not in seen:
seen.add(normalized)
print(normalized)
else:
for tensor_name in sorted(tensor_names):
print(tensor_name)
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
tensor_file = find_tensor_file(model_path, tensor_name)
if tensor_file is None:
print(f"Error: Could not find tensor '{tensor_name}' in model index")
print(f"Model path: {model_path}")
sys.exit(1)
file_path = model_path / tensor_file
try:
with safe_open(file_path, framework="pt", device="cpu") as f:
if tensor_name in f.keys():
tensor_slice = f.get_slice(tensor_name)
shape = tensor_slice.get_shape()
print(f"Tensor: {tensor_name}")
print(f"File: {tensor_file}")
print(f"Shape: {shape}")
if num_values is not None:
tensor = f.get_tensor(tensor_name)
print(f"Dtype: {tensor.dtype}")
flat = tensor.flatten()
n = min(num_values, flat.numel())
print(f"Values: {flat[:n].tolist()}")
else:
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
sys.exit(1)
except FileNotFoundError:
print(f"Error: The file '{file_path}' was not found.")
sys.exit(1)
except Exception as e:
print(f"An error occurred: {e}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
description="Print tensor information from a safetensors model"
)
parser.add_argument(
"tensor_name",
nargs="?", # optional (if --list is used for example)
help="Name of the tensor to inspect"
)
parser.add_argument(
"-m", "--model-path",
type=Path,
help="Path to the model directory (default: MODEL_PATH environment variable)"
)
parser.add_argument(
"-l", "--list",
action="store_true",
help="List unique tensor patterns in the model (layer numbers replaced with #)"
)
parser.add_argument(
"-n", "--num-values",
nargs="?",
const=10,
default=None,
type=int,
metavar="N",
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
)
args = parser.parse_args()
model_path = args.model_path
if model_path is None:
model_path_str = os.environ.get("MODEL_PATH")
if model_path_str is None:
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
sys.exit(1)
model_path = Path(model_path_str)
if not model_path.exists():
print(f"Error: Model path does not exist: {model_path}")
sys.exit(1)
if not model_path.is_dir():
print(f"Error: Model path is not a directory: {model_path}")
sys.exit(1)
if args.list:
list_all_tensors(model_path, unique=True)
else:
if args.tensor_name is None:
print("Error: tensor_name is required when not using --list")
sys.exit(1)
print_tensor_info(model_path, args.tensor_name, args.num_values)
if __name__ == "__main__":
main()

View File

@ -5,12 +5,15 @@
#include <vector>
#include <cstdio>
int main(int argc, char ** argv) {
common_params params;
params.prompt = "The quick brown fox";
params.sampling.seed = 1234;
const std::string_view state_file = "dump_state.bin";
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
@ -53,35 +56,16 @@ int main(int argc, char ** argv) {
// tokenize prompt
auto tokens = common_tokenize(ctx, params.prompt, true);
// prepare the batch
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
const bool save_state = true;
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
return 1;
}
batch.logits[batch.n_tokens - 1] = true; // generate next token
// evaluate prompt
llama_decode(ctx, batch);
n_past += batch.n_tokens;
// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
fclose(fp_write);
fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
}
// save state (last tokens)
const auto n_past_saved = n_past;
// first run
printf("\nfirst run: %s", params.prompt.c_str());
llama_batch batch = llama_batch_init(1, 0, 1);
for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl, ctx, -1);
auto next_token_str = common_token_to_piece(ctx, next_token);
@ -111,27 +95,23 @@ int main(int argc, char ** argv) {
printf("\nsecond run: %s", params.prompt.c_str());
// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem;
// load state from file
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
size_t n_token_count_out = 0;
FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
return 1;
}
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
fprintf(stderr, "\n%s : failed to load state\n", __func__);
return 1;
}
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
// restore state (last tokens)
n_past = n_past_saved;
n_past = n_token_count_out;
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
return 1;
}
++n_past;
// second run
for (auto i = 0; i < params.n_predict; i++) {
@ -160,7 +140,9 @@ int main(int argc, char ** argv) {
}
// make new context
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
auto params_ctx3 = common_context_params_to_llama(params);
params_ctx3.n_seq_max = 2;
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
@ -169,26 +151,21 @@ int main(int argc, char ** argv) {
printf("\nsingle seq run: %s", params.prompt.c_str());
// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem;
n_token_count_out = 0;
FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
return 1;
}
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
fprintf(stderr, "\n%s : failed to load state\n", __func__);
return 1;
}
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
// restore state (last tokens)
n_past = n_past_saved;
n_past = n_token_count_out;
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
return 1;
}
++n_past;
// save seq 0 and load into seq 1
{

View File

@ -42,6 +42,7 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
@ -55,9 +56,10 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -77,6 +79,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
@ -86,6 +89,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
@ -110,6 +114,7 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
@ -123,6 +128,7 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
@ -148,6 +154,7 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
@ -161,6 +168,7 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
@ -187,6 +195,7 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
@ -199,6 +208,7 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
@ -230,6 +240,7 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
@ -243,6 +254,7 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
@ -276,6 +288,7 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
@ -289,6 +302,7 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K

View File

@ -785,6 +785,165 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q5_K_8x4_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[col_groups];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int i = 0; i < col_groups; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d);
float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d);
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
int32x4_t acc_lo[col_groups];
int32x4_t acc_hi[col_groups];
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
int16_t bsums_arr[8];
vst1q_s16(bsums_arr, bsums);
uint8x16_t qh[col_groups][8];
for (int c = 0; c < col_groups; c++) {
for (int i = 0; i < 8; i++) {
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
}
}
for (int sb = 0; sb < QK_K / 64; sb++) {
for (int i = 0; i < col_groups; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q5sb_mins[2];
int16x8_t q5sb_scales[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q5sb[8];
const int offset = sb * 24 + i * 12;
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
}
int8x16_t q8_qs[4];
for (int i = 0; i < 4; i++) {
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
}
for (int c = 0; c < col_groups; c++) {
uint8x16_t q5_cols[8];
uint8x16_t hbit_lo[8];
uint8x16_t hbit_hi[8];
int8x16_t q5_lo[8];
int8x16_t q5_hi[8];
for (int i = 0; i < 8; i++) {
q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
hbit_lo[i] = vandq_u8(qh[c][i], mone);
hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
qh[c][i] = vshrq_n_u8(qh[c][i], 2);
q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
}
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
}
// Scales
// row c0123 blk0 and blk1
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
// row c4567 blk0 and blk1
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
// Bias Correction
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
} // for sb
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q5_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
@ -3205,6 +3364,235 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q5_K_8x4_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int q8_k_blocklen = 4;
constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
constexpr int col_groups = ncols_interleaved / 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
// 8 accumulators: 2 row pairs, 4 col pairs
float32x4_t acc_f32[acc_size];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int i = 0; i < acc_size; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
// d5 0 1 2 3, 4 5 6 7
float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
// d8 0 1 2 3
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
// mins
float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
// Precomputation of scales and mins
float32x4_t sbd_scale_0123[q8_k_blocklen];
float32x4_t sbd_scale_4567[q8_k_blocklen];
float32x4_t sbd_min_0123[q8_k_blocklen];
float32x4_t sbd_min_4567[q8_k_blocklen];
sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
const int16x8_t bsums[q8_k_blocklen] = {
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int16_t bsums_arr[QK_K / 64][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
}
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
int32x4_t bias_acc[acc_size];
for (int i = 0; i < acc_size; i++) {
bias_acc[i] = vdupq_n_s32(0);
}
uint8x16_t qh[col_groups][8];
for (int c = 0; c < col_groups; c++) {
for (int i = 0; i < 8; i++) {
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
}
}
for (int sb = 0; sb < QK_K / 64; sb++) {
// Int accumulators for qs vecdot (4 row * 2 col quartets)
int32x4_t acc_lo[acc_size];
int32x4_t acc_hi[acc_size];
for (int i = 0; i < acc_size; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q5sb_scales[2];
int16x8_t q5sb_mins[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q5sb[8];
const int offset = sb * 24 + i * 12;
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
}
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
for (int k = 0; k < reads_per_sb; k++) {
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
// 0..3 & 32..35
const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
// NOTE: This is the only difference with q4_K
const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
qh[0][k] = vshrq_n_u8(qh[0][k], 2);
const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
qh[1][k] = vshrq_n_u8(qh[1][k], 2);
// From here, same as q4_K
const int8x16_t q5_0123_lo =
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
const int8x16_t q5_0123_hi =
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
const int8x16_t q5_4567_lo =
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
const int8x16_t q5_4567_hi =
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
}
// Scale and bias application
// acc is stored interleaved to match output layout
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
for (int row = 0; row < q8_k_blocklen; row++) {
// Bias correction
// row c0123 blk0 and blk1
const float32x4_t sumf_0123 =
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
// row c4567 blk0 and blk1
const float32x4_t sumf_4567 =
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
// Bias
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
// row c0123 blk0 and blk1
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
// row c4567 blk0 and blk1
bias_acc[2 * row + 1] =
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
bias_acc[2 * row + 1] =
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
}
} // for sb
for (int row = 0; row < q8_k_blocklen; row++) {
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
acc_f32[2 * row + 1] =
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
}
} // for b
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,

View File

@ -450,6 +450,208 @@ static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
}
}
template <int M, int N>
static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int blocklen = M;
constexpr int ncols_interleaved = N;
const int qk = QK_K;
const int nb = n / qk;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[ncols_interleaved];
float sum_minf[ncols_interleaved];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
sum_minf[j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
constexpr int scale_stride = 32;
uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
const int qh_shift = (k / (32 / blocklen)) * 2;
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
const int qh_idx = (k * blocklen + i) % 32;
const int qh_chunk = qh_idx / blocklen;
const int qh_pos = qh_idx % blocklen;
const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
const uint8_t h0 = (qh_val >> qh_shift) & 1;
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i;
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
}
}
}
template <int M, int N>
static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int blocklen = M;
constexpr int ncols_interleaved = N;
const int qk = QK_K;
const int nb = n / qk;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
float sumf[4][ncols_interleaved];
float sum_minf[4][ncols_interleaved];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
sum_minf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
constexpr int scale_stride = 32;
uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
const int qh_shift = (k / (32 / blocklen)) * 2;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
const int qh_idx = (k * blocklen + i) % 32;
const int qh_chunk = qh_idx / blocklen;
const int qh_pos = qh_idx % blocklen;
const int b_qh_offset =
qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
const uint8_t h0 = (qh_val >> qh_shift) & 1;
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
const int q8_offset = (k / (32 / blocklen)) * 256 +
(k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i;
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
}
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int m = 0; m < 4; m++) {
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
}
}
}
}
}
extern "C" {
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -803,98 +1005,12 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
float sum_minf[8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
sum_minf[j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
const int qh_shift = (k / 4) * 2;
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
const int qh_idx = (k * 8 + i) % 32;
const int qh_chunk = qh_idx / 8;
const int qh_pos = qh_idx % 8;
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
const uint8_t h0 = (qh_val >> qh_shift) & 1;
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
}
}
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}
@ -1494,107 +1610,12 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
float sumf[4][8];
float sum_minf[4][8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
sum_minf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
const int qh_shift = (k / 4) * 2;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
const int qh_idx = (k * 8 + i) % 32;
const int qh_chunk = qh_idx / 8;
const int qh_pos = qh_idx % 8;
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
const uint8_t h0 = (qh_val >> qh_shift) & 1;
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
}
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int m = 0; m < 4; m++) {
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
}
}
}
}
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -2029,18 +2050,16 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
const int end = QK_K * 4 / blck_size_interleave;
// Interleave Q5_K quants by taking 8 bytes at a time
// Interleave Q5_K quants by taking blck_size_interleave bytes at a time
for (int i = 0; i < end; ++i) {
int src_id = i % 8;
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
uint64_t elems;
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
}
// Repeat for low bits 8 bytes at a time as well, since
// Repeat for high bits with the same chunk size, since
// the high bits are interleaved in Q5_K and the index is
// qh_idx = (qs_idx % 32);
// qh_val = qh[qh_idx] >> (qs_idx / 32);
@ -2049,9 +2068,7 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
uint64_t elems;
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave);
}
// The below logic is copied over from Q4_K
@ -2249,7 +2266,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
const void * GGML_RESTRICT data,
size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
GGML_ASSERT(interleave_block == 8);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
constexpr int nrows_interleaved = 8;
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
@ -2523,6 +2540,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q5_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size);
}
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
}
@ -2591,6 +2612,10 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -2654,6 +2679,10 @@ template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -3068,6 +3097,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for Q5_K
static const ggml::cpu::repack::tensor_traits<block_q5_K, 4, 8, GGML_TYPE_Q8_K> q5_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
// instance for Q6_K
@ -3130,6 +3160,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q5_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 8 == 0) {
return &q5_K_8x4_q8_K;
}
}
} else if (cur->type == GGML_TYPE_Q6_K) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 8 == 0) {

View File

@ -111,6 +111,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -122,6 +123,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -143,6 +145,7 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -154,6 +157,7 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

View File

@ -1149,8 +1149,7 @@ struct ggml_cuda_graph {
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
int number_consecutive_updates = 0;
bool warmup_complete = false;
std::vector<ggml_cuda_graph_node_properties> props;
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
@ -1159,21 +1158,9 @@ struct ggml_cuda_graph {
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
std::vector<ggml_cuda_graph_node_properties> extra;
void record_update(bool use_graph, bool update_required) {
if (use_graph && update_required) {
number_consecutive_updates++;
} else {
number_consecutive_updates = 0;
}
if (number_consecutive_updates >= 4) {
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
disable_due_to_too_many_updates = true;
}
}
bool is_enabled() const {
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env);
}
#endif
};

View File

@ -2979,10 +2979,6 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (graph->instance == nullptr) {
res = true;
}
// Check if the graph size has changed
if (graph->props.size() != (size_t)cgraph->n_nodes) {
res = true;
@ -3931,14 +3927,35 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
#ifdef USE_CUDA_GRAPH
graph_key = ggml_cuda_graph_get_key(cgraph);
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (graph->is_enabled()) {
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
if (graph_compatible) {
const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
graph->record_update(use_cuda_graph, cuda_graph_update_required);
if (!graph->warmup_complete) {
// Warmup: need at least 2 calls with no property change on the 2nd call
if (!properties_changed) {
graph->warmup_complete = true;
GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
use_cuda_graph = true;
cuda_graph_update_required = true;
}
// else: properties changed or first call - execute directly (use_cuda_graph stays false)
} else {
// Post-warmup: normal CUDA graph operation
if (properties_changed) {
// Properties changed - reset warmup, execute directly until stable again
graph->warmup_complete = false;
GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
} else {
use_cuda_graph = true;
cuda_graph_update_required = graph->instance == nullptr;
}
}
}
}
#endif // USE_CUDA_GRAPH

View File

@ -1749,23 +1749,6 @@ static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backe
return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
}
static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
if (x->ne[0] != y->ne[0]) {
return false;
}
if (x->ne[1] != y->ne[1]) {
return false;
}
if (x->ne[2] != y->ne[2]) {
return false;
}
if (x->ne[3] != y->ne[3]) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
@ -1797,43 +1780,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
return opt_experimental;
}
static bool hex_supported_src0_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
static bool hex_supported_src1_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
static bool hex_supported_src2_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
static bool hex_supported_src1_type2(ggml_type t) {
return t == GGML_TYPE_F16;
}
static bool hex_supported_src1_type3(ggml_type t) {
return t == GGML_TYPE_I32;
}
static bool hex_supported_dst_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
// TODO: support broadcast for ne[2 and 3]
if (x->ne[0] != y->ne[0]) {
return false;
}
if (x->ne[2] != y->ne[2]) {
return false;
}
if (x->ne[3] != y->ne[3]) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
@ -1919,19 +1865,19 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_src1_type(src1->type)) {
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dims2(src0, dst)) {
if (!ggml_are_same_shape(src0, dst)) {
return false;
}
if (!ggml_can_repeat(src1, src0)) {
if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {
return false;
}
@ -1943,16 +1889,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_src1_type(src1->type)) {
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dims2(src0, dst)) {
if (!ggml_are_same_shape(src0, dst)) {
return false;
}
@ -1968,13 +1914,13 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dims2(src0, dst)) {
if (!ggml_are_same_shape(src0, dst)) {
return false;
}
@ -1990,10 +1936,10 @@ static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session *
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
@ -2011,10 +1957,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
@ -2023,10 +1969,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
}
if (src1) {
if (!hex_supported_src1_type(src1->type)) {
if (src1->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dims2(src0, src1)) {
if (!ggml_are_same_shape(src0, src1)) {
return false;
}
if (!ggml_is_contiguous(src1)) {
@ -2047,15 +1993,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
return false; // FIXME: add support for sinks
}
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (src1) {
if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
return false;
}
if (src0->ne[0] != src1->ne[0]) {
@ -2162,17 +2108,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
const struct ggml_tensor * src2 = op->src[2];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
if (src0->type != GGML_TYPE_F32) {
return false; // FIXME: add support for GGML_TYPE_F16 for src0
}
if (!hex_supported_dst_type(dst->type)) {
if (dst->type != GGML_TYPE_F32) {
return false;
}
if (!hex_supported_src1_type3(src1->type)) {
if (src1->type != GGML_TYPE_I32) {
return false;
}
if (src2) {
if (!hex_supported_src2_type(src2->type)) {
if (src2->type != GGML_TYPE_F32) {
return false;
}
int n_dims = op_params[1];

View File

@ -69,27 +69,45 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
struct htp_act_context {
struct htp_ops_context * octx;
// Precomputed values
const uint8_t * data_src0;
const uint8_t * data_src1;
uint8_t * data_dst;
size_t src0_row_size;
size_t src1_row_size;
size_t dst_row_size;
size_t src0_row_size_aligned;
size_t src1_row_size_aligned;
size_t dst_row_size_aligned;
size_t src0_spad_half_size;
size_t src1_spad_half_size;
size_t dst_spad_half_size;
uint32_t block;
uint32_t src0_nrows;
uint32_t src0_nrows_per_thread;
int nc;
};
static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * src1 = &actx->octx->src1;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
size_t src0_row_size = actx->src0_row_size;
size_t src1_row_size = actx->src1_row_size;
size_t dst_row_size = actx->dst_row_size;
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const uint8_t * restrict data_src0 = actx->data_src0;
const uint8_t * restrict data_src1 = actx->data_src1;
uint8_t * restrict data_dst = actx->data_dst;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const int nc = actx->nc;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t src1_spad_half_size = actx->src1_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * src1 = &actx->octx->src1;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
size_t src0_row_size = actx->src0_row_size;
size_t src1_row_size = actx->src1_row_size;
size_t dst_row_size = actx->dst_row_size;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const uint8_t * restrict data_src0 = actx->data_src0;
const uint8_t * restrict data_src1 = actx->data_src1;
uint8_t * restrict data_dst = actx->data_dst;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const int nc = actx->nc;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t src1_spad_half_size = actx->src1_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
"%zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
const float alpha = ((const float *) (op_params))[2];
const float limit = ((const float *) (op_params))[3];
const float alpha = ((const float *) (actx->octx->op_params))[2];
const float limit = ((const float *) (actx->octx->op_params))[3];
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
}
static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble2;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t src0_row_size = actx->src0_row_size;
const size_t dst_row_size = actx->dst_row_size;
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * data_src0 = (const uint8_t *) src0->data;
uint8_t * data_dst = (uint8_t *) dst->data;
const uint8_t * data_src0 = actx->data_src0;
uint8_t * data_dst = actx->data_dst;
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
// nc/ne0 matches.
const int ne0_val = actx->nc; // == dst->ne[0]
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
// In gelu = x*sigmoid(x*1.702)
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// gelu = x * sigmoid(1.702 * x) // current implementation
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble2;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t src0_row_size = actx->src0_row_size;
const size_t dst_row_size = actx->dst_row_size;
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * data_src0 = (const uint8_t *) src0->data;
uint8_t * data_dst = (uint8_t *) dst->data;
const uint8_t * data_src0 = actx->data_src0;
uint8_t * data_dst = actx->data_dst;
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
const int ne0_val = actx->nc; // == dst->ne[0]
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// silu = x * sigmoid(x)
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
static const float GELU_COEF_A = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
struct htp_act_context * actx = (struct htp_act_context *) data;
const struct htp_tensor * src0 = &actx->octx->src0;
const struct htp_tensor * src1 = &actx->octx->src1;
const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
size_t src0_row_size = actx->src0_row_size;
size_t src1_row_size = actx->src1_row_size;
size_t dst_row_size = actx->dst_row_size;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_nrows = actx->src0_nrows;
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
return;
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const uint8_t * restrict data_src0 = actx->data_src0;
const uint8_t * restrict data_src1 = actx->data_src1;
uint8_t * restrict data_dst = actx->data_dst;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const int nc = actx->nc;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
size_t src0_spad_half_size = actx->src0_spad_half_size;
size_t src1_spad_half_size = actx->src1_spad_half_size;
size_t dst_spad_half_size = actx->dst_spad_half_size;
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static int execute_op_activations_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
switch (octx->op) {
case HTP_OP_UNARY_SILU:
act_op_func = unary_silu_f32;
act_op_func = (worker_callback_t)unary_silu_f32_per_thread;
op_type = "silu-f32";
break;
case HTP_OP_GLU_SWIGLU:
act_op_func = glu_swiglu_f32;
act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;
op_type = "swiglu-f32";
break;
case HTP_OP_GLU_SWIGLU_OAI:
act_op_func = glu_swiglu_oai_f32;
act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;
op_type = "swiglu-oai-f32";
break;
case HTP_OP_UNARY_GELU:
act_op_func = unary_gelu_f32;
act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;
op_type = "gelu-f32";
break;
case HTP_OP_GLU_GEGLU:
act_op_func = glu_geglu_f32;
act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;
op_type = "geglu-f32";
break;
default:
@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
}
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
return HTP_STATUS_OK;
}
return err;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
// Prepare context
struct htp_act_context actx;
actx.octx = octx;
actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
actx.src0_row_size = src0_row_size;
actx.src1_row_size = src1_row_size;
actx.dst_row_size = dst_row_size;
actx.src0_row_size_aligned = src0_row_size_aligned;
actx.src1_row_size_aligned = src1_row_size_aligned;
actx.dst_row_size_aligned = dst_row_size_aligned;
actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2;
actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
actx.src0_nrows = src0_nrows;
actx.nc = dst->ne[0];
// Pointers and GLU logic
const uint8_t * data_src0 = (const uint8_t *) src0->data;
const uint8_t * data_src1 = (const uint8_t *) src1->data;
if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
const int32_t swapped = octx->op_params[1];
data_src1 = data_src0;
actx.src1_row_size = actx.src0_row_size;
size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
if (swapped) {
data_src0 += nc_in_bytes;
} else {
data_src1 += nc_in_bytes;
}
}
actx.data_src0 = data_src0;
actx.data_src1 = data_src1;
actx.data_dst = (uint8_t *) dst->data;
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
return HTP_STATUS_OK;
}
int op_activations(struct htp_ops_context * octx) {

View File

@ -15,6 +15,13 @@
#include "htp-ops.h"
#include "hvx-utils.h"
struct get_rows_context {
struct htp_ops_context * octx;
uint32_t src1_nrows_per_thread;
struct fastdiv_values get_rows_div_ne10;
struct fastdiv_values get_rows_div_ne10_ne11;
};
#define get_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
@ -39,20 +46,22 @@
\
const uint32_t nr = ne10 * ne11 * ne12;
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
struct get_rows_context * grctx = (struct get_rows_context *)data;
struct htp_ops_context * octx = grctx->octx;
get_rows_preamble;
// parallelize by src1 elements (which correspond to dst rows)
const uint32_t dr = octx->src1_nrows_per_thread;
const uint32_t dr = grctx->src1_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
const uint32_t rem = i - i12 * ne11 * ne10;
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
const uint32_t i10 = rem - i11 * ne10;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
return HTP_STATUS_OK;
}
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_get_rows(struct htp_ops_context * octx) {
@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
struct get_rows_context grctx;
grctx.octx = octx;
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
return HTP_STATUS_OK;
}

View File

@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q,
dmlink(q->tail, desc);
q->tail = desc;
// FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
// FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
dptr = q->dptr[q->pop_idx];
// FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
// FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
return dptr;
}
static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {
dma_ptr dptr = { NULL };
if (q->push_idx == q->pop_idx) {
return dptr;
}
dptr = q->dptr[q->pop_idx];
// FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
return dptr;
}
static inline bool dma_queue_empty(dma_queue * q) {
return q->push_idx == q->pop_idx;
}
static inline uint32_t dma_queue_depth(dma_queue * q) {
return (q->push_idx - q->pop_idx) & q->idx_mask;
}
static inline uint32_t dma_queue_capacity(dma_queue * q) {
return q->capacity;
}
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -44,32 +44,6 @@ struct htp_ops_context {
uint32_t src0_nrows_per_thread;
uint32_t src1_nrows_per_thread;
struct fastdiv_values src0_div1; // fastdiv values for ne1
struct fastdiv_values src0_div2; // fastdiv values for ne2
struct fastdiv_values src0_div3; // fastdiv values for ne3
struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values src1_div1; // fastdiv values for ne1
struct fastdiv_values src1_div2; // fastdiv values for ne2
struct fastdiv_values src1_div3; // fastdiv values for ne3
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values src3_div1; // fastdiv values for ne1
struct fastdiv_values src3_div2; // fastdiv values for ne2
struct fastdiv_values src3_div3; // fastdiv values for ne3
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
uint32_t flags;
};

View File

@ -49,62 +49,6 @@ struct htp_matmul_context {
struct fastdiv_values mm_div_r3;
};
// vdelta control to replicate first 4x fp32 values across lanes
static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
};
// vdelta control to replicate and interleave first 8x fp32 values across lanes
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
};
// vdelta control to replicate first fp32 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
// vdelta control to replicate first fp16 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
// vdelta control to replicate first fp16 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
@ -2067,10 +2011,10 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert to QF32
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
// Combine and convert to fp16
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
@ -2080,11 +2024,6 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
@ -2130,13 +2069,8 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Compute max and scale
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
@ -2179,11 +2113,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric
// Compute max and scale
HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);

View File

@ -10,6 +10,7 @@
#include "hex-dma.h"
#include "hvx-utils.h"
#include "hex-fastdiv.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
@ -21,6 +22,9 @@
#define HTP_ROPE_TYPE_NORMAL 0
#define HTP_ROPE_TYPE_NEOX 2
#define HTP_ROPE_SPAD_NROWS 16
#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
#define htp_rope_preamble \
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
@ -42,7 +46,7 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
struct rope_th_ctx {
struct htp_rope_context {
int32_t n_dims;
int32_t mode;
int32_t n_ctx_orig;
@ -57,7 +61,19 @@ struct rope_th_ctx {
float theta_scale;
float corr_dims[2];
uint32_t src0_nrows_per_thread;
size_t spad_stride;
struct htp_ops_context * octx;
size_t src0_row_size;
size_t dst_row_size;
size_t src0_row_size_aligned;
size_t dst_row_size_aligned;
size_t theta_cache_offset;
uint32_t src0_nrows;
uint64_t t_start;
};
static float rope_yarn_ramp(const float low, const float high, const int i0) {
@ -117,64 +133,23 @@ static void rope_corr_dims(int n_dims,
dims[1] = MIN(n_dims - 1, end);
}
static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
const int32_t * op_params = &octx->op_params[0];
uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2
rope_ctx->n_dims = ((const int32_t *) op_params)[1];
rope_ctx->mode = ((const int32_t *) op_params)[2];
rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
uint32_t he = ne / 2; // half_dims offset in elements
uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
#pragma unroll(2)
for (uint32_t i = 0; i < nvec; i += 2) {
HVX_Vector v0 = vsrc[i/2+0];
HVX_Vector v1 = vsrc[i/2+hv];
rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
rope_ctx->beta_slow, rope_ctx->corr_dims);
rope_ctx->octx = octx;
FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
}
static void hvx_calc_rope_neox_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
// for (int i = 0; i < num_elems; i += 2) {
//const float cos_theta = theta_cache[i + 0];
//const float sin_theta = theta_cache[i + 1];
//const float x0 = src[0];
//const float x1 = src[num_elems/2];
//dst[0] = x0*cos_theta - x1*sin_theta;
//dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
//src += 1;
//dst += 1;
// }
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
uint8_t * restrict dst_curr = (uint8_t *) dst;
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
int half_size = (sizeof(float) * (num_elems / 2));
for (int i = 0; i < step_of_1; i++) {
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
HVX_Vector v2 = vtheta[i+0];
HVX_Vector v3 = vtheta[i+1];
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4);
vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
}
src0_curr += VLEN;
theta_curr += 2 * VLEN;
dst_curr += VLEN;
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
const float cos_theta = theta_cache[i+0];
const float sin_theta = theta_cache[i+1];
float x0 = src0[i/2];
float x1 = src0[i/2 + he];
dst[i/2] = x0 * cos_theta - x1 * sin_theta;
dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
}
}
static void hvx_calc_rope_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
// for (int i = 0; i < num_elems; i += 2) {
//const float cos_theta = theta_cache[i + 0];
//const float sin_theta = theta_cache[i + 1];
static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
//const float x0 = src[0];
//const float x1 = src[1];
uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
//dst[0] = x0*cos_theta - x1*sin_theta;
//dst[1] = x0*sin_theta + x1*cos_theta;
#pragma unroll(2)
for (uint32_t i = 0; i < nvec; i+=2) {
HVX_Vector v0 = vsrc[i+0];
HVX_Vector v1 = vsrc[i+1];
//src += 2;
//dst += 2;
// }
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
uint8_t * restrict dst_curr = (uint8_t *) dst;
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
for (int i = 0; i < step_of_1; i++) {
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
HVX_Vector v2 = vtheta[i+0];
HVX_Vector v3 = vtheta[i+1];
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
@ -239,116 +203,65 @@ static void hvx_calc_rope_f32(const float * restrict src0,
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
*(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
*(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
vdst[i+0] = Q6_V_lo_W(vstore);
vdst[i+1] = Q6_V_hi_W(vstore);
}
src0_curr += 2 * VLEN;
theta_curr += 2 * VLEN;
dst_curr += 2 * VLEN;
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
const float cos_theta = theta_cache[i+0];
const float sin_theta = theta_cache[i+1];
float x0 = src0[i+0];
float x1 = src0[i+1];
dst[i+0] = x0 * cos_theta - x1 * sin_theta;
dst[i+1] = x0 * sin_theta + x1 * cos_theta;
}
}
static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
const uint32_t ir0,
const uint32_t ir1,
int nth,
int ith,
const int opt_path) {
struct htp_ops_context * octx = rope_ctx->octx;
static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
#pragma unroll(4)
for (uint32_t i = 0; i < nr; i++) {
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);
// fill the remain channels with data from src tensor
if (rctx->n_dims < ne0) {
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
}
}
}
static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
#pragma unroll(4)
for (uint32_t i = 0; i < nr; i++) {
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);
// fill the remain channels with data from src tensor
if (rctx->n_dims < ne0) {
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
}
}
}
static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
struct htp_rope_context * rctx = (struct htp_rope_context *) data;
struct htp_ops_context * octx = rctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
const int32_t mode = rope_ctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
htp_rope_preamble;
const int32_t * pos = (const int32_t *) src1->data;
float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
const float * freq_factors = NULL;
if (src2 != NULL) {
freq_factors = (const float *) src2->data;
}
const uint32_t i1_end = MIN(ir1, ne1);
const int32_t half_dims = rope_ctx->n_dims / 2;
const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
const int32_t p = pos[i2];
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
const float * src_loc = src;
float * dst_data_loc = dst_data;
if (1 == opt_path) {
if (is_neox) {
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
} else {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
}
src_loc += rope_ctx->n_dims;
dst_data_loc += rope_ctx->n_dims;
} else {
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
const float cos_theta = wp0[i0 + 0];
const float sin_theta = wp0[i0 + 1];
if (is_neox) {
const float x0 = src_loc[0];
const float x1 = src_loc[half_dims];
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
src_loc += 1;
dst_data_loc += 1;
} else {
const float x0 = src_loc[0];
const float x1 = src_loc[1];
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
src_loc += 2;
dst_data_loc += 2;
}
}
src_loc += (is_neox ? half_dims : 0);
dst_data_loc += (is_neox ? half_dims : 0);
}
// TODO: use simd to speed up the remaining elements copy
memcpy(dst_data_loc, src_loc, remain_bytes);
}
}
}
}
static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
struct htp_ops_context * octx = rope_ctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
htp_rope_preamble;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const uint32_t src0_nrows = rctx->src0_nrows;
const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -358,32 +271,114 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int
return;
}
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
uint64_t tt = HAP_perf_get_qtimer_count();
int is_aligned = 1;
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
is_aligned = 0;
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
const int32_t mode = rctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
// VTCM setup
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
float * theta_cache = (float *) (src0_spad_base);
src0_spad_base = src0_spad_base + rctx->theta_cache_offset;
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
dma_queue * dma_queue = octx->ctx->dma[ith];
const int32_t * pos = (const int32_t *) src1->data;
const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
uint32_t ir = 0;
uint32_t prev_i2 = (uint32_t) -1;
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
if (ir < src0_start_row) { ir++; i1++; continue; }
if (ir >= src0_end_row) goto done;
// Rows in this block
const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);
// Depth before prefetch
uint32_t dma_depth = dma_queue_depth(dma_queue);
// FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth,
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
// Prefetch loop
for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {
pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);
uint32_t pi1 = i1 + pr;
uint32_t pir = ir + pr;
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
}
// Update theta cache
if (i2 != prev_i2) {
prev_i2 = i2;
const int32_t p = pos[i2];
rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
}
// Skip DMA transactions from prev block (if any)
// No need to wait for these since the DMA is setup for in-order processing
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
// Compute loop
for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {
// Number of rows to compute
cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);
uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;
uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;
// FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr,
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
if (is_neox) {
rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
} else {
rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
}
uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);
// Prefetch more rows (if any)
if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {
uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);
uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;
uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
}
}
}
}
}
rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
done:
dma_queue_flush(dma_queue);
tt = HAP_perf_get_qtimer_count() - tt;
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
rope_job_f32_per_thread(rope_ctx, n, i);
FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));
}
static int execute_op_rope_f32(struct htp_ops_context * octx) {
@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
worker_callback_t op_func;
const char * op_type = NULL;
struct rope_th_ctx rope_ctx;
const char * op_type = "rope-f32";
switch (octx->op) {
case HTP_OP_ROPE:
op_func = rope_job_dispatcher_f32;
op_type = "rope-f32";
init_rope_ctx(&rope_ctx, octx);
break;
default:
@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
const uint32_t n_threads = octx->n_threads;
const size_t src0_row_size = src0->nb[1];
const size_t src1_row_size = src0_row_size;
const size_t dst_row_size = dst->nb[1];
// VTCM scratchpads for all tensors
// N rows per thread, padded to HVX vector size
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
// Aligned row sizes for VTCM
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
// Calculate spad sizes per thread
size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;
size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;
if (src2->ne[0]) {
FARF(HIGH,
"%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
"dst-spad-size %u\n",
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
} else {
FARF(HIGH,
"%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
octx->dst_spad.size);
}
// Make sure the reserved vtcm size is sufficient
if (octx->ctx->vtcm_size < spad_size) {
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
spad_size);
// Check if we fit in VTCM
size_t total_vtcm_needed = spad_per_thread * n_threads;
if (octx->ctx->vtcm_size < total_vtcm_needed) {
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
// Assign sizes
octx->src0_spad.size_per_thread = src0_spad_per_thread;
octx->dst_spad.size_per_thread = dst_spad_per_thread;
octx->src0_spad.size = n_threads * src0_spad_per_thread;
octx->dst_spad.size = n_threads * dst_spad_per_thread;
octx->src1_spad.size = 0;
// Assign pointers
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src1_spad.data = NULL;
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
// Fill context
struct htp_rope_context rctx;
memset(&rctx, 0, sizeof(struct htp_rope_context));
rctx.t_start = HAP_perf_get_qtimer_count();
rctx.octx = octx;
const int32_t * op_params = &octx->op_params[0];
rctx.n_dims = ((const int32_t *) op_params)[1];
rctx.mode = ((const int32_t *) op_params)[2];
rctx.n_ctx_orig = ((const int32_t *) op_params)[4];
memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float));
memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float));
memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float));
memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float));
memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float));
memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float));
memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4);
rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);
rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);
rctx.src0_row_size = src0_row_size;
rctx.dst_row_size = dst_row_size;
rctx.src0_row_size_aligned = src0_row_size_aligned;
rctx.dst_row_size_aligned = dst_row_size_aligned;
rctx.theta_cache_offset = theta_cache_size_aligned;
uint32_t ne0 = dst->ne[0];
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
rctx.src0_nrows = src0_nrows;
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
uint32_t n_jobs = MIN(n_threads, src0_nrows);
rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
}
return err;

View File

@ -43,11 +43,21 @@
\
const uint32_t nr = ne01;
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
struct htp_set_rows_context {
struct htp_ops_context * octx;
struct fastdiv_values div_ne12;
struct fastdiv_values div_ne11;
uint32_t src0_nrows_per_thread;
};
static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
struct htp_ops_context * octx = srctx->octx;
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t dr = srctx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
}
}
}
return HTP_STATUS_OK;
}
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
struct htp_ops_context * octx = srctx->octx;
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t dr = srctx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
}
}
}
return HTP_STATUS_OK;
}
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
}
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_set_rows(struct htp_ops_context * octx) {
@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
struct htp_set_rows_context srctx;
srctx.octx = octx;
srctx.div_ne12 = init_fastdiv_values(ne12);
srctx.div_ne11 = init_fastdiv_values(ne11);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
switch(octx->dst.type) {
case HTP_TYPE_F32:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
break;
case HTP_TYPE_F16:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
break;
default:
return HTP_STATUS_NO_SUPPORT;

View File

@ -10,6 +10,7 @@
#include "hex-dma.h"
#include "hvx-utils.h"
#include "hex-fastdiv.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
@ -48,7 +49,7 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
struct softmax_th_ctx {
struct htp_softmax_context {
bool use_f16;
bool use_src1;
uint32_t n_head;
@ -59,28 +60,48 @@ struct softmax_th_ctx {
float m0;
float m1;
uint32_t src0_nrows_per_thread;
struct fastdiv_values fastdiv_ne01;
struct fastdiv_values fastdiv_ne02;
struct fastdiv_values fastdiv_ne12; // For mask broadcasting
struct fastdiv_values fastdiv_ne13; // For mask broadcasting
size_t spad_stride;
struct htp_ops_context * octx;
};
static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
memset(smctx, 0, sizeof(struct htp_softmax_context));
memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
softmax_ctx->n_head = src0->ne[2];
softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
smctx->n_head = src0->ne[2];
smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));
softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
softmax_ctx->use_src1 = (src1->ne[0] != 0);
softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
smctx->use_src1 = (src1->ne[0] != 0);
smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
softmax_ctx->octx = octx;
smctx->octx = octx;
// Initialize fastdiv values
const uint32_t ne01 = src0->ne[1];
const uint32_t ne02 = src0->ne[2];
if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
}
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
@ -139,8 +160,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
}
HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
max_vec = hvx_vec_repl4(v);
max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes
#pragma unroll(4)
for (int i = 0; i < step_of_1; i++) {
@ -154,8 +174,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
v_pad[i] = v3;
}
v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
sum_vec = hvx_vec_repl4(v);
sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
@ -183,83 +202,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
return sum;
}
static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
struct htp_ops_context * octx = softmax_ctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
const struct htp_tensor * dst = &octx->dst;
htp_softmax_preamble3;
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
float * wp0 = (float *) src0_spad_data;
float * wp1 = (float *) src1_spad_data;
float * wp2 = (float *) dst_spad_data;
for (uint32_t i03 = 0; i03 < ne03; i03++) {
for (uint32_t i02 = 0; i02 < ne02; i02++) {
for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
const uint32_t i11 = i01;
const uint32_t i12 = i02 % ne12;
const uint32_t i13 = i03 % ne13;
// ALiBi
const uint32_t h = i02; // head
const float slope = (softmax_ctx->max_bias > 0.0f) ?
h < softmax_ctx->n_head_log2 ?
powf(softmax_ctx->m0, h + 1) :
powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
1.0f;
float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
// broadcast the mask across rows
__fp16 * mp_f16 = (softmax_ctx->use_src1) ?
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
NULL;
float * mp_f32 = (softmax_ctx->use_src1) ?
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
NULL;
if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
(const uint8_t *) mp_f32, slope);
} else {
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
if (mp_f32) {
if (softmax_ctx->use_f16) {
for (int i = 0; i < ne00; ++i) {
wp0[i] += slope * (float) mp_f16[i];
}
} else {
for (int i = 0; i < ne00; ++i) {
wp0[i] += slope * mp_f32[i];
}
}
}
}
if (1 == opt_path) {
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
} else {
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
sum = sum > 0.0 ? (1.0 / sum) : 1;
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
}
}
}
}
}
static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
struct htp_ops_context * octx = softmax_ctx->octx;
static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
struct htp_ops_context * octx = smctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
@ -268,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
htp_softmax_preamble3;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -291,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
opt_path = 1;
}
softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride);
float * wp0 = (float *) src0_spad_data;
float * wp1 = (float *) src1_spad_data;
float * wp2 = (float *) dst_spad_data;
uint32_t prev_i2 = (uint32_t)-1;
float slope = 1.0f;
for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {
uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);
uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);
uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);
uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);
// Map to original logic indices
// i01 = i1
// i02 = i2
// i03 = i3
const uint32_t i11 = i1;
// const uint32_t i12 = i2 % ne12;
// const uint32_t i13 = i3 % ne13;
uint32_t i12, i13;
if (ne12 == ne02) {
i12 = i2;
} else {
i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);
}
if (ne13 == ne03) {
i13 = i3;
} else {
i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);
}
// ALiBi
if (i2 != prev_i2) {
const uint32_t h = i2; // head
slope = (smctx->max_bias > 0.0f) ?
h < smctx->n_head_log2 ?
powf(smctx->m0, h + 1) :
powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
1.0f;
prev_i2 = i2;
}
float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
// broadcast the mask across rows
__fp16 * mp_f16 = (smctx->use_src1) ?
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
NULL;
float * mp_f32 = (smctx->use_src1) ?
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
NULL;
if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
(const uint8_t *) mp_f32, slope);
} else {
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
if (mp_f32) {
if (smctx->use_f16) {
for (int i = 0; i < ne00; ++i) {
wp0[i] += slope * (float) mp_f16[i];
}
} else {
for (int i = 0; i < ne00; ++i) {
wp0[i] += slope * mp_f32[i];
}
}
}
}
if (1 == opt_path) {
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
} else {
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
sum = sum > 0.0 ? (1.0 / sum) : 1;
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
}
}
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
softmax_job_f32_per_thread(p_softmax_ctx, n, i);
}
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
@ -312,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
worker_callback_t op_func;
const char * op_type = NULL;
struct softmax_th_ctx softmax_ctx;
struct htp_softmax_context smctx;
const char * op_type = "softmax-f32";
switch (octx->op) {
case HTP_OP_SOFTMAX:
op_func = softmax_job_dispatcher_f32;
op_type = "softmax-f32";
init_softmax_ctx(&softmax_ctx, octx);
init_softmax_ctx(&smctx, octx);
break;
default:
@ -342,6 +365,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
// Use stride for calculating offset
smctx.spad_stride = hex_round_up(src0_row_size, 128);
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
if (src1->ne[0]) {
@ -371,8 +397,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
}
return err;

View File

@ -17,7 +17,6 @@
#include "htp-msg.h"
#include "htp-ops.h"
#define sum_rows_preamble \
struct htp_tensor *src0 = &octx->src0;\
struct htp_tensor *dst = &octx->dst; \
@ -42,53 +41,54 @@
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3]; \
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
sum_rows_preamble;
struct sum_rows_context {
const uint8_t * src_data;
uint8_t * dst_data;
uint32_t ne00;
size_t src_stride;
size_t dst_stride;
uint32_t rows_per_thread;
uint32_t total_rows;
bool opt_path;
};
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) {
const struct sum_rows_context * smctx = (const struct sum_rows_context *) data;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t rows_per_thread = smctx->rows_per_thread;
const uint32_t total_rows = smctx->total_rows;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
const uint32_t start_row = rows_per_thread * ith;
const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return HTP_STATUS_OK;
if (start_row >= end_row) {
return;
}
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
}
const size_t src_stride = smctx->src_stride;
const size_t dst_stride = smctx->dst_stride;
const uint32_t ne00 = smctx->ne00;
const bool opt_path = smctx->opt_path;
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride));
float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride));
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
// Calculate actual number of rows for this thread
const uint32_t n_rows = end_row - start_row;
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
const float * restrict src_local = src_th + (ir * ne00);
for (uint32_t ir = 0; ir < n_rows; ir++) {
const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float)));
if (ir + 1 < src0_nrows_per_thread) {
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
if (ir + 1 < n_rows) {
hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1);
}
if (1 == opt_path) {
if (opt_path) {
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
} else {
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
}
}
return HTP_STATUS_OK;
}
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
}
int op_sum_rows(struct htp_ops_context * octx) {
@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) {
const uint32_t src0_nrows = ne01 * ne02 * ne03;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
bool opt_path = false;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
opt_path = true;
}
struct sum_rows_context smctx = {
.src_data = (const uint8_t *) src0->data,
.dst_data = (uint8_t *) dst->data,
.ne00 = ne00,
.src_stride = nb01,
.dst_stride = nb1,
.rows_per_thread = rows_per_thread,
.total_rows = src0_nrows,
.opt_path = opt_path,
};
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
return HTP_STATUS_OK;
}

View File

@ -17,6 +17,28 @@
#include "htp-msg.h"
#include "htp-ops.h"
struct htp_unary_context {
struct htp_ops_context * octx;
// Precomputed values
const uint8_t * data_src0;
uint8_t * data_dst;
size_t src0_row_size;
size_t dst_row_size;
size_t src0_row_size_aligned;
size_t dst_row_size_aligned;
size_t src0_spad_half_size;
size_t dst_spad_half_size;
uint32_t block;
uint32_t src0_nrows;
uint32_t src0_nrows_per_thread;
uint32_t nc;
};
#define htp_unary_preamble \
const uint32_t ne00 = src->ne[0]; \
const uint32_t ne01 = src->ne[1]; \
@ -57,8 +79,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
}
HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
sum_v = hvx_vec_repl4(reduced_sum);
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
@ -75,128 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
}
}
static void scale_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
static void scale_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
float scale = 0.f;
float bias = 0.f;
memcpy(&scale, &op_params[0], sizeof(float));
memcpy(&bias, &op_params[1], sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
}
}
static void rms_norm_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
static void rms_norm_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
float epsilon = 0.f;
memcpy(&epsilon, op_params, sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
} else {
float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
const float mean = sum / row_elems;
const float scale = 1.0f / sqrtf(mean + epsilon);
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
}
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
}
}
static void sqr_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
static void sqr_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
static void sqrt_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
static void sqrt_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
static void unary_job_f32_per_thread(const struct htp_tensor * src,
struct htp_tensor * dst,
uint8_t * spad,
int htp_op,
int32_t * op_params,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread) {
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
struct htp_ops_context * octx = uctx->octx;
const struct htp_tensor * src = &octx->src0;
const struct htp_tensor * dst = &octx->dst;
htp_unary_preamble;
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
int htp_op = octx->op;
int32_t * op_params = octx->op_params;
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const size_t src0_row_size = uctx->src0_row_size;
const size_t dst_row_size = uctx->dst_row_size;
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
const uint32_t src0_nrows = uctx->src0_nrows;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@ -208,79 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
int is_aligned = 1;
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
is_aligned = 0;
}
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
const uint8_t * restrict data_src = uctx->data_src0;
uint8_t * restrict data_dst = uctx->data_dst;
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
size_t src0_spad_half_size = uctx->src0_spad_half_size;
size_t dst_spad_half_size = uctx->dst_spad_half_size;
const int BLOCK = uctx->block;
if (BLOCK == 0) {
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
const uint8_t * restrict data_src = (const uint8_t *) src->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
dma_queue * dma_queue = octx->ctx->dma[ith];
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
switch (htp_op) {
case HTP_OP_RMS_NORM:
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQR:
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQRT:
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
default:
break;
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
// Process block in VTCM
switch (htp_op) {
case HTP_OP_RMS_NORM:
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_SCALE:
scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_SQR:
sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
case HTP_OP_SQRT:
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
default:
break;
}
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
dst_row_size, dst_row_size_aligned, block_size);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
octx->src0_nrows_per_thread);
}
static int execute_op_unary_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
const struct htp_tensor * src0 = &octx->src0;
struct htp_tensor * dst = &octx->dst;
worker_callback_t unary_op_func;
const char * op_type = NULL;
const char * op_type = NULL;
switch (octx->op) {
case HTP_OP_RMS_NORM:
unary_op_func = unary_job_dispatcher_f32;
op_type = "rmsnorm-f32";
op_type = "rmsnorm-f32";
break;
case HTP_OP_SCALE:
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
op_type = "scale-f32";
break;
case HTP_OP_SQR:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqr-f32";
op_type = "sqr-f32";
break;
case HTP_OP_SQRT:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqrt-f32";
op_type = "sqrt-f32";
break;
default:
@ -294,32 +307,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
const size_t src0_row_size = src0->nb[1];
const size_t dst_row_size = dst->nb[1];
// VTCM scratchpads for all tensors
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
// VTCM scratchpads for all tensors
// N rows per thread, padded to HVX vector size
// Double buffering requires 2x size per buffer
size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
// Make sure the reserved vtcm size is sufficient
if (vtcm_row_per_thread == 0) {
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
spad_size_per_row * n_threads);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2;
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
// Make sure the reserved vtcm size is sufficient
if (octx->ctx->vtcm_size < spad_size) {
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
spad_size);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
struct htp_unary_context uctx = {
.octx = octx,
.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
.src0_nrows = src0_nrows,
worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
.data_src0 = (const uint8_t *)src0->data,
.data_dst = (uint8_t *)dst->data,
.src0_row_size = src0_row_size,
.dst_row_size = dst_row_size,
.src0_row_size_aligned = src0_row_size_aligned,
.dst_row_size_aligned = dst_row_size_aligned,
.src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
.dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
.block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
.nc = src0->ne[0],
};
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
}
return err;

View File

@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --batch-size 128 -fa on \
-ngl 99 --device $device $cli_opts $@ \
--ctx-size 8192 --ubatch-size 256 -fa on \
-ngl 99 --device $device $cli_opts $@ \
"

View File

@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --batch-size 128 -fa on \
-ngl 99 -no-cnv --device $device $cli_opts $@ \
--ctx-size 8192 --ubatch-size 256 -fa on \
-ngl 99 -no-cnv --device $device $cli_opts $@ \
"

View File

@ -58,11 +58,11 @@ adb $adbserial $adbhost shell " \
cd $basedir; ulimit -c unlimited; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
--mmproj $basedir/../gguf/$mmproj \
--image $basedir/../gguf/$image \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
-ngl 99 --device $device -v $cli_opts $@ \
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
--mmproj $basedir/../gguf/$mmproj \
--image $basedir/../gguf/$image \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
--ctx-size 8192 --ubatch-size 256 -fa on \
-ngl 99 --device $device -v $cli_opts $@ \
"

View File

@ -49,5 +49,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib"
& "$basedir\bin\llama-completion.exe" `
--no-mmap -no-cnv -m $basedir\..\..\gguf\$model `
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on `
--ctx-size 8192 --ubatch-size 128 -fa on `
-ngl 99 --device $device $cli_opts

View File

@ -5,7 +5,7 @@ import os
import sys
import subprocess
HTTPLIB_VERSION = "d4180e923f846b44a3d30acd938438d6e64fc9f6"
HTTPLIB_VERSION = "refs/tags/v0.34.0"
vendor = {
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",

View File

@ -2440,64 +2440,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
// TODO: add more model-specific info which should prevent loading the session file if not identical
}
// write output ids
{
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
std::vector<int32_t> w_output_pos;
w_output_pos.resize(n_outputs);
// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch(); ++i) {
// map an output id to a position in the batch
int64_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
}
}
io.write(&n_outputs, sizeof(n_outputs));
if (n_outputs) {
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
}
}
// [TAG_CONTEXT_STATE_LOGITS]
// write logits
{
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
io.write(&logits_size, sizeof(logits_size));
if (logits_size) {
io.write(logits.data, logits_size * sizeof(float));
}
}
// write embeddings
{
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
io.write(&embd_size, sizeof(embd_size));
if (embd_size) {
io.write(embd.data, embd_size * sizeof(float));
}
}
// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004
if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
memory->state_write(io);
@ -2523,70 +2465,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
// TODO: add more info which needs to be identical but which is not verified otherwise
}
// read output ids
{
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
auto n_outputs = this->n_outputs;
io.read_to(&n_outputs, sizeof(n_outputs));
if (n_outputs > output_reserve(n_outputs)) {
throw std::runtime_error("could not reserve outputs");
}
std::vector<int32_t> output_pos;
if (n_outputs) {
output_pos.resize(n_outputs);
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
int32_t id = output_pos[i];
if ((uint32_t) id >= n_batch()) {
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
}
this->output_ids[id] = i;
}
this->n_outputs = n_outputs;
}
}
// read logits
{
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
uint64_t logits_size;
io.read_to(&logits_size, sizeof(logits_size));
if (this->logits.size < logits_size) {
throw std::runtime_error("logits buffer too small");
}
if (logits_size) {
io.read_to(this->logits.data, logits_size * sizeof(float));
}
}
// read embeddings
{
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
uint64_t embd_size;
io.read_to(&embd_size, sizeof(embd_size));
if (this->embd.size < embd_size) {
throw std::runtime_error("embeddings buffer too small");
}
if (embd_size) {
io.read_to(this->embd.data, embd_size * sizeof(float));
}
}
// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004
if (memory) {
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

View File

@ -1703,8 +1703,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_DEEPSEEK2:
{
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256));
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);

View File

@ -2027,7 +2027,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
} else if (
tokenizer_pre == "gpt-4o" ||
tokenizer_pre == "llama4") {
tokenizer_pre == "llama4" ||
tokenizer_pre == "kanana2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
clean_spaces = false;
} else if (

View File

@ -361,7 +361,7 @@ static void test_backend_temp_sampling(const test_params & params) {
GGML_ASSERT(false && "Failed to decode token");
}
// Verfify sequence 0
// Verify sequence 0
{
int32_t batch_idx = test_ctx.idx_for_seq(0);
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
@ -379,7 +379,7 @@ static void test_backend_temp_sampling(const test_params & params) {
}
// Verfify sequence 1
// Verify sequence 1
{
int32_t batch_idx = test_ctx.idx_for_seq(1);
@ -395,7 +395,7 @@ static void test_backend_temp_sampling(const test_params & params) {
}
}
// lambda to testing non-positive temperature values.
// lambda for testing non-positive temperature values.
auto test_argmax_temp = [&](float temp) {
printf("\nTesting temperature = %.1f\n", temp);
@ -454,7 +454,7 @@ static void test_backend_temp_ext_sampling(const test_params & params) {
}
}
// lambda to testing non-positive temp/delta/exponent values.
// lambda for testing non-positive temp/delta/exponent values.
auto test_argmax_temp = [&](float temp, float delta, float exponent) {
printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent);
@ -530,7 +530,7 @@ static void test_backend_min_p_sampling(const test_params & params) {
printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
// Decode and sampler 10 more tokens
// Decode and sample 10 more tokens
for (int i = 0; i < 10; i++) {
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
@ -582,7 +582,7 @@ static void test_backend_top_p_sampling(const test_params & params) {
printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
// Decode and sampler 10 more tokens
// Decode and sample 10 more tokens
for (int i = 0; i < 10; i++) {
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
@ -619,7 +619,7 @@ static void test_backend_multi_sequence_sampling(const test_params & params) {
GGML_ASSERT(false && "Failed to decode token");
}
// Verfiy sequence 0
// Verify sequence 0
{
int32_t batch_idx = test_ctx.idx_for_seq(0);
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
@ -763,7 +763,7 @@ static void test_backend_logit_bias_sampling(const test_params & params) {
printf("backend logit bias sampling test PASSED\n");
}
// This test verifies that it is possible to have two different backend sampler,
// This test verifies that it is possible to have two different backend samplers,
// one that uses the backend dist sampler, and another that uses CPU dist sampler.
static void test_backend_mixed_sampling(const test_params & params) {
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
@ -791,7 +791,7 @@ static void test_backend_mixed_sampling(const test_params & params) {
GGML_ASSERT(false && "Failed to decode token");
}
// Verfiy sequence 0 that used the dist backend sampler.
// Verify sequence 0 that used the dist backend sampler.
{
int32_t batch_idx = test_ctx.idx_for_seq(0);
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
@ -802,7 +802,7 @@ static void test_backend_mixed_sampling(const test_params & params) {
//GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
}
// Verfiy sequence 1 that used the top-k backend sampler.
// Verify sequence 1 that used the top-k backend sampler.
{
int32_t batch_idx = test_ctx.idx_for_seq(1);
float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
@ -934,7 +934,7 @@ static void test_backend_cpu_mixed_batch(const test_params & params) {
// samplers.
llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);
// Create a CPU sampler and verify we can sampler from it.
// Create a CPU sampler and verify we can sample from it.
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());

View File

@ -229,6 +229,20 @@ common_chat_tool python_tool {
"required": ["code"]
})",
};
common_chat_tool todo_list_tool {
/* .name = */ "todo_list",
/* .description = */ "Create or update the todo list",
/* .parameters = */ R"({
"type": "object",
"properties": {
"todos": {
"type": "array",
"description": "List of TODO list items"
}
},
"required": ["todos"]
})",
};
common_chat_tool code_interpreter_tool {
/* .name = */ "code_interpreter",
/* .description = */ "an ipython interpreter",
@ -3018,542 +3032,6 @@ Hey there!<|im_end|>
);
}
// Test Qwen3-Coder XML format
{
// Basic XML tool call parsing
assert_msg_equals(
message_assist_call,
test_chat_parse(
"<tool_call>\n"
" <function=special_function>\n"
" <parameter=arg1>\n"
" 1\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_QWEN3_CODER_XML}));
// Multiple parameters with different types
common_chat_msg expected_multi_param;
expected_multi_param.role = "assistant";
expected_multi_param.tool_calls = {
{ "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}", "" }
};
test_parser_with_streaming(expected_multi_param,
"<tool_call>\n"
" <function=complex_function>\n"
" <parameter=name>\n"
" John Doe\n"
" </parameter>\n"
" <parameter=age>\n"
" 30\n"
" </parameter>\n"
" <parameter=active>\n"
" true\n"
" </parameter>\n"
" <parameter=score>\n"
" 95.5\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Special characters and Unicode
common_chat_msg expected_special_chars;
expected_special_chars.role = "assistant";
expected_special_chars.tool_calls = {
{ "unicode_function", "{\"message\":\"Hello 世界! 🌍 Special chars: @#$%^&*()\"}", "" }
};
test_parser_with_streaming(expected_special_chars,
"<tool_call>\n"
" <function=unicode_function>\n"
" <parameter=message>\n"
" Hello 世界! 🌍 Special chars: @#$%^&*()\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Multiline content with newlines and indentation
common_chat_msg expected_multiline;
expected_multiline.role = "assistant";
expected_multiline.tool_calls = {
{ "code_function", "{\"code\":\"def hello():\\n print(\\\"Hello, World!\\\")\\n return True\"}", "" }
};
test_parser_with_streaming(expected_multiline,
"<tool_call>\n"
" <function=code_function>\n"
" <parameter=code>\n"
"def hello():\n"
" print(\"Hello, World!\")\n"
" return True\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// JSON object as parameter value
common_chat_msg expected_json_param;
expected_json_param.role = "assistant";
expected_json_param.tool_calls = {
{ "json_function", "{\"config\":{\"host\":\"localhost\",\"port\":8080,\"ssl\":false}}", "" }
};
test_parser_with_streaming(
expected_json_param,
"<tool_call>\n"
" <function=json_function>\n"
" <parameter=config>\n"
" {\"host\": \"localhost\", \"port\": 8080, \"ssl\": false}\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Array as parameter value
common_chat_msg expected_array_param;
expected_array_param.role = "assistant";
expected_array_param.tool_calls = {
{ "array_function", "{\"items\":[\"apple\",\"banana\",\"cherry\"]}", "" }
};
test_parser_with_streaming(
expected_array_param,
"<tool_call>\n"
" <function=array_function>\n"
" <parameter=items>\n"
" [\"apple\", \"banana\", \"cherry\"]\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Empty parameter
common_chat_msg expected_empty_param;
expected_empty_param.role = "assistant";
expected_empty_param.tool_calls = {
{ "empty_function", "{\"empty_param\":\"\"}", "" }
};
test_parser_with_streaming(
expected_empty_param,
"<tool_call>\n"
" <function=empty_function>\n"
" <parameter=empty_param>\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Boolean values (true/false)
common_chat_msg expected_boolean;
expected_boolean.role = "assistant";
expected_boolean.tool_calls = {
{ "boolean_function", "{\"enabled\":true,\"debug\":false}", "" }
};
test_parser_with_streaming(
expected_boolean,
"<tool_call>\n"
" <function=boolean_function>\n"
" <parameter=enabled>\n"
" true\n"
" </parameter>\n"
" <parameter=debug>\n"
" false\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Null value
common_chat_msg expected_null;
expected_null.role = "assistant";
expected_null.tool_calls = {
{ "null_function", "{\"optional_param\":null}", "" }
};
test_parser_with_streaming(
expected_null,
"<tool_call>\n"
" <function=null_function>\n"
" <parameter=optional_param>\n"
" null\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Negative numbers and scientific notation
common_chat_msg expected_numbers;
expected_numbers.role = "assistant";
expected_numbers.tool_calls = {
{ "math_function", "{\"negative\":-42,\"decimal\":-3.14,\"scientific\":1.23e-4}", "" }
};
test_parser_with_streaming(
expected_numbers,
"<tool_call>\n"
" <function=math_function>\n"
" <parameter=negative>\n"
" -42\n"
" </parameter>\n"
" <parameter=decimal>\n"
" -3.14\n"
" </parameter>\n"
" <parameter=scientific>\n"
" 1.23e-4\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// XML-like content in parameters (should be escaped)
common_chat_msg expected_xml_content;
expected_xml_content.role = "assistant";
expected_xml_content.tool_calls = {
{ "xml_function", "{\"xml_content\":\"<root><item>value</item></root>\"}", "" }
};
test_parser_with_streaming(
expected_xml_content,
"<tool_call>\n"
" <function=xml_function>\n"
" <parameter=xml_content>\n"
" <root><item>value</item></root>\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Quotes and escape characters
common_chat_msg expected_quotes;
expected_quotes.role = "assistant";
expected_quotes.tool_calls = {
{ "quote_function", "{\"message\":\"She said \\\"Hello!\\\" and left.\"}", "" }
};
test_parser_with_streaming(
expected_quotes,
"<tool_call>\n"
" <function=quote_function>\n"
" <parameter=message>\n"
" She said \"Hello!\" and left.\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Long parameter value (simplified)
std::string long_text = "This is a long text parameter that should test the parser's ability to handle larger amounts of text data.";
common_chat_msg expected_long_text;
expected_long_text.role = "assistant";
expected_long_text.tool_calls = {
{ "long_function", "{\"long_text\":\"" + long_text + "\"}", "" }
};
test_parser_with_streaming(
expected_long_text,
"<tool_call>\n"
" <function=long_function>\n"
" <parameter=long_text>\n"
" " + long_text + "\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Mixed content with text before and after tool call
common_chat_msg expected_mixed_content;
expected_mixed_content.role = "assistant";
expected_mixed_content.content = "I'll help you search for products. ";
expected_mixed_content.tool_calls = {
{ "search_function", "{\"query\":\"laptops\"}", "" }
};
test_parser_with_streaming(
expected_mixed_content,
"I'll help you search for products. <tool_call>\n"
" <function=search_function>\n"
" <parameter=query>\n"
" laptops\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Compact format (no extra whitespace)
common_chat_msg expected_compact;
expected_compact.role = "assistant";
expected_compact.tool_calls = {
{ "compact_function", "{\"param\":\"value\"}", "" }
};
test_parser_with_streaming(
expected_compact,
"<tool_call><function=compact_function><parameter=param>value</parameter></function></tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Function name with underscores and numbers
common_chat_msg expected_complex_name;
expected_complex_name.role = "assistant";
expected_complex_name.tool_calls = {
{ "get_user_data_v2", "{\"user_id\":12345}", "" }
};
test_parser_with_streaming(
expected_complex_name,
"<tool_call>\n"
" <function=get_user_data_v2>\n"
" <parameter=user_id>\n"
" 12345\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Parameter names with underscores and numbers
common_chat_msg expected_complex_params;
expected_complex_params.role = "assistant";
expected_complex_params.tool_calls = {
{ "test_function", "{\"param_1\":\"value1\",\"param_2_name\":\"value2\",\"param3\":123}", "" }
};
test_parser_with_streaming(
expected_complex_params,
"<tool_call>\n"
" <function=test_function>\n"
" <parameter=param_1>\n"
" value1\n"
" </parameter>\n"
" <parameter=param_2_name>\n"
" value2\n"
" </parameter>\n"
" <parameter=param3>\n"
" 123\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Very deeply nested XML content in parameter
common_chat_msg expected_deep_xml;
expected_deep_xml.role = "assistant";
expected_deep_xml.tool_calls = {
{ "xml_parser", "{\"xml\":\"<root><level1><level2><level3>deep content</level3></level2></level1></root>\"}", "" }
};
test_parser_with_streaming(
expected_deep_xml,
"<tool_call>\n"
" <function=xml_parser>\n"
" <parameter=xml>\n"
" <root><level1><level2><level3>deep content</level3></level2></level1></root>\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Parameter with only whitespace
common_chat_msg expected_whitespace_param;
expected_whitespace_param.role = "assistant";
expected_whitespace_param.tool_calls = {
{ "whitespace_function", "{\"spaces\":\"\"}", "" }
};
test_parser_with_streaming(
expected_whitespace_param,
"<tool_call>\n"
" <function=whitespace_function>\n"
" <parameter=spaces>\n"
" \n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Parameter with tabs and mixed whitespace
common_chat_msg expected_mixed_whitespace;
expected_mixed_whitespace.role = "assistant";
expected_mixed_whitespace.tool_calls = {
{ "tab_function", "{\"content\":\"line1\\n\\tindented line\\n spaces\"}", "" }
};
test_parser_with_streaming(
expected_mixed_whitespace,
"<tool_call>\n"
" <function=tab_function>\n"
" <parameter=content>\n"
"line1\n"
"\tindented line\n"
" spaces\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Control characters and special Unicode
common_chat_msg expected_control_chars;
expected_control_chars.role = "assistant";
expected_control_chars.tool_calls = {
{ "control_function", "{\"text\":\"Line1\\nLine2\\tTabbed\\rCarriage return\"}", "" }
};
test_parser_with_streaming(
expected_control_chars,
"<tool_call>\n"
" <function=control_function>\n"
" <parameter=text>\n"
"Line1\nLine2\tTabbed\rCarriage return\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Emoji and extended Unicode characters
common_chat_msg expected_emoji;
expected_emoji.role = "assistant";
expected_emoji.tool_calls = {
{ "emoji_function", "{\"message\":\"Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\"}", "" }
};
test_parser_with_streaming(
expected_emoji,
"<tool_call>\n"
" <function=emoji_function>\n"
" <parameter=message>\n"
" Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Mathematical expressions and formulas
common_chat_msg expected_math;
expected_math.role = "assistant";
expected_math.tool_calls = {
{ "math_function", "{\"formula\":\"E = mc² and ∫f(x)dx = F(x) + C\"}", "" }
};
test_parser_with_streaming(
expected_math,
"<tool_call>\n"
" <function=math_function>\n"
" <parameter=formula>\n"
" E = mc² and ∫f(x)dx = F(x) + C\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// SQL injection-like content (should be safely escaped)
common_chat_msg expected_sql;
expected_sql.role = "assistant";
expected_sql.tool_calls = {
{ "sql_function", "{\"query\":\"SELECT * FROM users WHERE id = 1; DROP TABLE users; --\"}", "" }
};
test_parser_with_streaming(
expected_sql,
"<tool_call>\n"
" <function=sql_function>\n"
" <parameter=query>\n"
" SELECT * FROM users WHERE id = 1; DROP TABLE users; --\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// HTML/XML injection content
common_chat_msg expected_html;
expected_html.role = "assistant";
expected_html.tool_calls = {
{ "html_function", "{\"content\":\"<script>alert('xss')</script><img src=x onerror=alert(1)>\"}", "" }
};
test_parser_with_streaming(
expected_html,
"<tool_call>\n"
" <function=html_function>\n"
" <parameter=content>\n"
" <script>alert('xss')</script><img src=x onerror=alert(1)>\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Binary-like content (base64)
common_chat_msg expected_binary;
expected_binary.role = "assistant";
expected_binary.tool_calls = {
{ "binary_function", "{\"data\":\"SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\"}", "" }
};
test_parser_with_streaming(
expected_binary,
"<tool_call>\n"
" <function=binary_function>\n"
" <parameter=data>\n"
" SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
// Very large numbers (should be parsed as scientific notation)
common_chat_msg expected_large_numbers;
expected_large_numbers.role = "assistant";
expected_large_numbers.tool_calls = {
{ "number_function", "{\"big_int\":1e+60}", "" } // Large number becomes scientific notation
};
test_parser_with_streaming(
expected_large_numbers,
"<tool_call>\n"
" <function=number_function>\n"
" <parameter=big_int>\n"
" 999999999999999999999999999999999999999999999999999999999999\n"
" </parameter>\n"
" </function>\n"
"</tool_call>",
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
}
{
// Qwen3-Coder template
auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja");
common_chat_templates_inputs inputs;
inputs.messages = { message_user };
common_chat_tool qwen_union_tool {
/* .name = */ "qwen_union",
/* .description = */ "Test tool for union/anyOf handling",
/* .parameters = */ R"({
"type": "object",
"properties": {
"priority": { "type": ["number", "null"] },
"maybe_text": { "anyOf": [ { "type": "string" } ] },
"config": { "anyOf": [ { "type": "object" }, { "type": "null" } ] }
},
"required": []
})",
};
inputs.tools = { qwen_union_tool };
auto params = common_chat_templates_apply(tmpls.get(), inputs);
assert_equals(COMMON_CHAT_FORMAT_QWEN3_CODER_XML, params.format);
assert_equals(false, params.grammar.empty());
// Grammar should compile successfully
auto grammar = build_grammar(params.grammar);
GGML_ASSERT(grammar && "Failed to build Qwen3-Coder grammar with union types");
}
{
// Step-3.5-Flash template: uses same XML output format as Qwen3-Coder and Nemotron v3,
// but with <think> support. Routes to the Nemotron v3 PEG parser for streaming and
@ -3665,6 +3143,135 @@ static void test_template_output_peg_parsers() {
});
}
{
// Qwen3-Coder
auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja");
// Test basic message
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = "Hello, world!\nWhat's up?";
t.expect = message_assist;
});
// Test tool call
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input =
"<tool_call>\n"
"<function=special_function>\n"
"<parameter=arg1>\n"
"1\n"
"</parameter>\n"
"</function>\n"
"</tool_call>";
t.params.tools = {special_function_tool};
t.expect = message_assist_call;
});
// Test parallel tool calls
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input =
"<tool_call>\n"
"<function=special_function>\n"
"<parameter=arg1>\n"
"1\n"
"</parameter>\n"
"</function>\n"
"</tool_call>\n"
"<tool_call>\n"
"<function=special_function_with_opt>\n"
"<parameter=arg1>\n"
"1\n"
"</parameter>\n"
"<parameter=arg2>\n"
"2\n"
"</parameter>\n"
"</function>\n"
"</tool_call>";
t.params.parallel_tool_calls = true;
t.params.tools = {special_function_tool, special_function_tool_with_optional_param};
t.expect.tool_calls = {{
/* .name = */ "special_function",
/* .arguments = */ R"({"arg1": 1})",
/* .id = */ {},
}, {
/* .name = */ "special_function_with_opt",
/* .arguments = */ R"({"arg1": 1, "arg2": 2})",
/* .id = */ {},
}};
});
// Test tool call with string parameter
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input =
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Hello, world!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call>";
t.params.tools = {python_tool};
t.expect.tool_calls = {{
/* .name = */ "python",
/* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}",
/* .id = */ {},
}};
});
// Test tool call with JSON parameter
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input =
"<tool_call>\n"
"<function=todo_list>\n"
"<parameter=todos>\n"
"[{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]\n"
"</parameter>\n"
"</function>\n"
"</tool_call>";
t.params.tools = {todo_list_tool};
t.expect.tool_calls = {{
/* .name = */ "todo_list",
/* .arguments = */ "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}",
/* .id = */ {},
}};
});
// Test tool call with string parameter and no closing </parameter> tag
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input =
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Hello, world!\")\n"
"\n"
"hello()\n"
"</function>\n"
"</tool_call>";
t.params.tools = {python_tool};
t.expect.tool_calls = {{
/* .name = */ "python",
/* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}",
/* .id = */ {},
}};
});
// Test response format
test_peg_parser(tmpls.get(), [&](auto & t) {
t.input = R"({"amount": 123.45, "date": "2025-12-03"})";
t.params.json_schema = invoice_schema;
t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})";
});
}
{
// NVIDIA Nemotron-3 Nano
auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja");

View File

@ -32,6 +32,7 @@ static void test_string_methods(testing & t);
static void test_array_methods(testing & t);
static void test_object_methods(testing & t);
static void test_hasher(testing & t);
static void test_stats(testing & t);
static void test_fuzzing(testing & t);
static bool g_python_mode = false;
@ -70,6 +71,7 @@ int main(int argc, char *argv[]) {
t.test("object methods", test_object_methods);
if (!g_python_mode) {
t.test("hasher", test_hasher);
t.test("stats", test_stats);
t.test("fuzzing", test_fuzzing);
}
@ -1795,6 +1797,63 @@ static void test_hasher(testing & t) {
});
}
static void test_stats(testing & t) {
static auto get_stats = [](const std::string & tmpl, const json & vars) -> jinja::value {
jinja::lexer lexer;
auto lexer_res = lexer.tokenize(tmpl);
jinja::program prog = jinja::parse_from_tokens(lexer_res);
jinja::context ctx(tmpl);
jinja::global_from_json(ctx, json{{ "val", vars }}, true);
ctx.is_get_stats = true;
jinja::runtime runtime(ctx);
runtime.execute(prog);
return ctx.get_val("val");
};
t.test("stats", [](testing & t) {
jinja::value val = get_stats(
"{{val.num}} "
"{{val.str}} "
"{{val.arr[0]}} "
"{{val.obj.key1}} "
"{{val.nested | tojson}}",
// Note: the json below will be wrapped inside "val" in the context
json{
{"num", 1},
{"str", "abc"},
{"arr", json::array({1, 2, 3})},
{"obj", json::object({{"key1", 1}, {"key2", 2}, {"key3", 3}})},
{"nested", json::object({
{"inner_key1", json::array({1, 2})},
{"inner_key2", json::object({{"a", "x"}, {"b", "y"}})}
})},
{"mixed", json::object({
{"used", 1},
{"unused", 2},
})},
}
);
t.assert_true("num is used", val->at("num")->stats.used);
t.assert_true("str is used", val->at("str")->stats.used);
t.assert_true("arr is used", val->at("arr")->stats.used);
t.assert_true("arr[0] is used", val->at("arr")->at(0)->stats.used);
t.assert_true("arr[1] is not used", !val->at("arr")->at(1)->stats.used);
t.assert_true("obj is used", val->at("obj")->stats.used);
t.assert_true("obj.key1 is used", val->at("obj")->at("key1")->stats.used);
t.assert_true("obj.key2 is not used", !val->at("obj")->at("key2")->stats.used);
t.assert_true("inner_key1[0] is used", val->at("nested")->at("inner_key1")->at(0)->stats.used);
t.assert_true("inner_key2.a is used", val->at("nested")->at("inner_key2")->at("a")->stats.used);
});
}
static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
t.test(name, [&tmpl, &vars, &expect](testing & t) {
jinja::lexer lexer;

View File

@ -380,6 +380,15 @@ int main(int argc, char ** argv) {
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
continue;
}
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
cur_msg += fname;
cur_msg.push_back('\n');
} else {
cur_msg += "--- File: ";
cur_msg += fname;
cur_msg += " ---\n";
}
cur_msg += marker;
console::log("Loaded text from '%s'\n", fname.c_str());
continue;

View File

@ -387,6 +387,17 @@ int main(int argc, char ** argv) {
}
session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
// Logits are not stored as part of the session state so we need to
// "replay" the last token to get logits for sampling.
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) {
return 1;
}
session_do_save = false;
LOG_INF("%s: replayed last token from session\n", __func__);
}
}
// number of tokens to keep when resetting context
@ -675,40 +686,27 @@ int main(int argc, char ** argv) {
}
if (!embd.empty()) {
int n_eval = (int) embd.size();
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
GGML_ASSERT(n_eval <= params.n_batch);
if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
const bool save_now = session_do_save && is_last_batch;
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) {
return 1;
}
n_past += n_eval;
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
n_session_consumed = session_tokens.size();
session_do_save = false;
LOG_DBG("n_past = %d\n", n_past);
// Display total tokens alongside total time
if (params.n_print > 0 && n_past % params.n_print == 0) {
LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
}
}
if (!embd.empty() && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
}
embd.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// optionally save the session on first sample (for faster prompt loading next time)
if (session_do_save) {
session_do_save = false;
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
LOG_DBG("saved session to %s\n", path_session.c_str());
}
const llama_token id = common_sampler_sample(smpl, ctx, -1);

Binary file not shown.

View File

@ -1105,6 +1105,8 @@ json convert_responses_to_chatcmpl(const json & response_body) {
};
for (json item : input_value) {
bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant";
if (exists_and_is_string(item, "content")) {
// #responses_create-input-input_item_list-input_message-content-text_input
// Only "Input message" contains item["content"]::string
@ -1193,7 +1195,7 @@ json convert_responses_to_chatcmpl(const json & response_body) {
item.at("type") == "message"
) {
// #responses_create-input-input_item_list-item-output_message
std::vector<json> chatcmpl_content;
auto chatcmpl_content = json::array();
for (const auto & output_text : item.at("content")) {
const std::string type = json_value(output_text, "type", std::string());
@ -1210,10 +1212,19 @@ json convert_responses_to_chatcmpl(const json & response_body) {
});
}
item.erase("status");
item.erase("type");
item["content"] = chatcmpl_content;
chatcmpl_messages.push_back(item);
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
if (!exists_and_is_array(prev_msg, "content")) {
prev_msg["content"] = json::array();
}
auto & prev_content = prev_msg["content"];
prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end());
} else {
item.erase("status");
item.erase("type");
item["content"] = chatcmpl_content;
chatcmpl_messages.push_back(item);
}
} else if (exists_and_is_string(item, "arguments") &&
exists_and_is_string(item, "call_id") &&
exists_and_is_string(item, "name") &&
@ -1221,24 +1232,27 @@ json convert_responses_to_chatcmpl(const json & response_body) {
item.at("type") == "function_call"
) {
// #responses_create-input-input_item_list-item-function_tool_call
json msg = json {
{"role", "assistant"},
{"tool_calls", json::array({ json {
{"function", json {
{"arguments", item.at("arguments")},
{"name", item.at("name")},
}},
{"id", item.at("call_id")},
{"type", "function"},
}})},
json tool_call = {
{"function", json {
{"arguments", item.at("arguments")},
{"name", item.at("name")},
}},
{"id", item.at("call_id")},
{"type", "function"},
};
if (!chatcmpl_messages.empty() && chatcmpl_messages.back().contains("reasoning_content")) {
// Move reasoning content from dummy message to tool call message
msg["reasoning_content"] = chatcmpl_messages.back().at("reasoning_content");
chatcmpl_messages.pop_back();
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
if (!exists_and_is_array(prev_msg, "tool_calls")) {
prev_msg["tool_calls"] = json::array();
}
prev_msg["tool_calls"].push_back(tool_call);
} else {
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"tool_calls", json::array({tool_call})}
});
}
chatcmpl_messages.push_back(msg);
} else if (exists_and_is_string(item, "call_id") &&
(exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) &&
exists_and_is_string(item, "type") &&
@ -1282,12 +1296,16 @@ json convert_responses_to_chatcmpl(const json & response_body) {
throw std::invalid_argument("item['content']['text'] is not a string");
}
// Pack reasoning content in dummy message
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"content", json::array()},
{"reasoning_content", item.at("content")[0].at("text")},
});
if (merge_prev) {
auto & prev_msg = chatcmpl_messages.back();
prev_msg["reasoning_content"] = item.at("content")[0].at("text");
} else {
chatcmpl_messages.push_back(json {
{"role", "assistant"},
{"content", json::array()},
{"reasoning_content", item.at("content")[0].at("text")},
});
}
} else {
throw std::invalid_argument("Cannot determine type of 'item'");
}
@ -1296,20 +1314,6 @@ json convert_responses_to_chatcmpl(const json & response_body) {
throw std::invalid_argument("'input' must be a string or array of objects");
}
// Remove unused dummy message which contains
// reasoning content not followed by tool call
chatcmpl_messages.erase(std::remove_if(
chatcmpl_messages.begin(),
chatcmpl_messages.end(),
[](const json & x){ return x.contains("role") &&
x.at("role") == "assistant" &&
x.contains("content") &&
x.at("content") == json::array() &&
x.contains("reasoning_content");
}),
chatcmpl_messages.end()
);
chatcmpl_body["messages"] = chatcmpl_messages;
if (response_body.contains("tools")) {

View File

@ -2911,6 +2911,9 @@ server_context_meta server_context::get_meta() const {
/* fim_pre_token */ llama_vocab_fim_pre(impl->vocab),
/* fim_sub_token */ llama_vocab_fim_suf(impl->vocab),
/* fim_mid_token */ llama_vocab_fim_mid(impl->vocab),
/* fim_pad_token */ llama_vocab_fim_pad(impl->vocab),
/* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
/* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),
/* model_vocab_type */ llama_vocab_type(impl->vocab),
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),

View File

@ -30,6 +30,9 @@ struct server_context_meta {
llama_token fim_pre_token;
llama_token fim_sub_token;
llama_token fim_mid_token;
llama_token fim_pad_token;
llama_token fim_rep_token;
llama_token fim_sep_token;
// model meta
enum llama_vocab_type model_vocab_type;

View File

@ -101,7 +101,7 @@ In a separate terminal, start the backend server:
./llama-server -m model.gguf
# Multi-model (ROUTER mode)
./llama-server --model-store /path/to/models
./llama-server --models-dir /path/to/models
```
### 3. Start Development Servers

View File

@ -114,6 +114,11 @@
label: 'Render user content as Markdown',
type: SettingsFieldType.CHECKBOX
},
{
key: SETTINGS_KEYS.FULL_HEIGHT_CODE_BLOCKS,
label: 'Use full height code blocks',
type: SettingsFieldType.CHECKBOX
},
{
key: SETTINGS_KEYS.DISABLE_AUTO_SCROLL,
label: 'Disable automatic scroll',

View File

@ -38,6 +38,8 @@
import { ActionIconsCodeBlock, DialogCodePreview } from '$lib/components/app';
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
import type { DatabaseMessageExtra } from '$lib/types/database';
import { config } from '$lib/stores/settings.svelte';
import { SETTINGS_KEYS } from '$lib/constants/settings-keys';
interface Props {
attachments?: DatabaseMessageExtra[];
@ -593,7 +595,12 @@
});
</script>
<div bind:this={containerRef} class={className}>
<div
bind:this={containerRef}
class="{className}{config()[SETTINGS_KEYS.FULL_HEIGHT_CODE_BLOCKS]
? ' full-height-code-blocks'
: ''}"
>
{#each renderedBlocks as block (block.id)}
<div class="markdown-block" data-block-id={block.id}>
<!-- eslint-disable-next-line no-at-html-tags -->
@ -914,6 +921,16 @@
line-height: 1.3;
}
.full-height-code-blocks :global(.code-block-wrapper) {
max-height: none;
}
.full-height-code-blocks :global(.code-block-scroll-container),
.full-height-code-blocks .streaming-code-scroll-container {
max-height: none;
overflow-y: visible;
}
div :global(.code-block-header) {
display: flex;
justify-content: space-between;

View File

@ -251,9 +251,6 @@
return options.find((option) => option.id === activeId);
}
if (options.length === 1) {
return options[0];
}
// No selection - return undefined to show "Select model"
return undefined;
}

View File

@ -22,6 +22,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
alwaysShowSidebarOnDesktop: false,
autoShowSidebarOnNewChat: true,
autoMicOnEmpty: false,
fullHeightCodeBlocks: false,
// make sure these default values are in sync with `common.h`
samplers: 'top_k;typ_p;top_p;min_p;temperature',
backend_sampling: false,
@ -113,6 +114,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
'Automatically show sidebar when starting a new chat. Disable to keep the sidebar hidden until you click on it.',
autoMicOnEmpty:
'Automatically show microphone button instead of send button when textarea is empty for models with audio modality support.',
fullHeightCodeBlocks:
'Always display code blocks at their full natural height, overriding any height limits.',
pyInterpreterEnabled:
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.',
enableContinueGeneration:

View File

@ -23,6 +23,7 @@ export const SETTINGS_KEYS = {
DISABLE_AUTO_SCROLL: 'disableAutoScroll',
ALWAYS_SHOW_SIDEBAR_ON_DESKTOP: 'alwaysShowSidebarOnDesktop',
AUTO_SHOW_SIDEBAR_ON_NEW_CHAT: 'autoShowSidebarOnNewChat',
FULL_HEIGHT_CODE_BLOCKS: 'fullHeightCodeBlocks',
// Sampling
TEMPERATURE: 'temperature',
DYNATEMP_RANGE: 'dynatemp_range',

View File

@ -153,6 +153,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
serverKey: 'enableContinueGeneration',
type: SyncableParameterType.BOOLEAN,
canSync: true
},
{
key: 'fullHeightCodeBlocks',
serverKey: 'fullHeightCodeBlocks',
type: SyncableParameterType.BOOLEAN,
canSync: true
}
];

View File

@ -306,6 +306,16 @@ class ModelsStore {
const response = await ModelsService.listRouter();
this.routerModels = response.data;
await this.fetchModalitiesForLoadedModels();
const o = this.models.filter((option) => {
const modelProps = this.getModelProps(option.model);
return modelProps?.webui !== false;
});
if (o.length === 1 && this.isModelLoaded(o[0].model)) {
this.selectModelById(o[0].id);
}
} catch (error) {
console.warn('Failed to fetch router models:', error);
this.routerModels = [];

File diff suppressed because it is too large Load Diff

View File

@ -8,8 +8,8 @@
#ifndef CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_HTTPLIB_H
#define CPPHTTPLIB_VERSION "0.32.0"
#define CPPHTTPLIB_VERSION_NUM "0x002000"
#define CPPHTTPLIB_VERSION "0.34.0"
#define CPPHTTPLIB_VERSION_NUM "0x002200"
/*
* Platform compatibility check
@ -185,6 +185,14 @@
: 0))
#endif
#ifndef CPPHTTPLIB_THREAD_POOL_MAX_COUNT
#define CPPHTTPLIB_THREAD_POOL_MAX_COUNT (CPPHTTPLIB_THREAD_POOL_COUNT * 4)
#endif
#ifndef CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT
#define CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT 3 // seconds
#endif
#ifndef CPPHTTPLIB_RECV_FLAGS
#define CPPHTTPLIB_RECV_FLAGS 0
#endif
@ -201,6 +209,22 @@
#define CPPHTTPLIB_MAX_LINE_LENGTH 32768
#endif
#ifndef CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH
#define CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH 16777216
#endif
#ifndef CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND
#define CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND 300
#endif
#ifndef CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND
#define CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND 5
#endif
#ifndef CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND
#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30
#endif
/*
* Headers
*/
@ -310,6 +334,7 @@ using socket_t = int;
#include <errno.h>
#include <exception>
#include <fcntl.h>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
@ -328,6 +353,9 @@ using socket_t = int;
#include <unordered_map>
#include <unordered_set>
#include <utility>
#if __cplusplus >= 201703L
#include <any>
#endif
#if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \
defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
@ -415,10 +443,46 @@ using socket_t = int;
#endif // CPPHTTPLIB_MBEDTLS_SUPPORT
#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT
#include <wolfssl/options.h>
#include <wolfssl/openssl/x509v3.h>
// Fallback definitions for older wolfSSL versions (e.g., 5.6.6)
#ifndef WOLFSSL_GEN_EMAIL
#define WOLFSSL_GEN_EMAIL 1
#endif
#ifndef WOLFSSL_GEN_DNS
#define WOLFSSL_GEN_DNS 2
#endif
#ifndef WOLFSSL_GEN_URI
#define WOLFSSL_GEN_URI 6
#endif
#ifndef WOLFSSL_GEN_IPADD
#define WOLFSSL_GEN_IPADD 7
#endif
#include <wolfssl/ssl.h>
#include <wolfssl/wolfcrypt/hash.h>
#include <wolfssl/wolfcrypt/md5.h>
#include <wolfssl/wolfcrypt/sha256.h>
#include <wolfssl/wolfcrypt/sha512.h>
#ifdef _WIN32
#include <wincrypt.h>
#ifdef _MSC_VER
#pragma comment(lib, "crypt32.lib")
#endif
#endif // _WIN32
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
#if TARGET_OS_MAC
#include <Security/Security.h>
#endif
#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
#endif // CPPHTTPLIB_WOLFSSL_SUPPORT
// Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available
// This simplifies conditional compilation when adding new backends (e.g.,
// wolfSSL)
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || defined(CPPHTTPLIB_MBEDTLS_SUPPORT)
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \
defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT)
#define CPPHTTPLIB_SSL_ENABLED
#endif
@ -440,6 +504,10 @@ using socket_t = int;
*/
namespace httplib {
namespace ws {
class WebSocket;
} // namespace ws
namespace detail {
/*
@ -711,6 +779,143 @@ using Match = std::smatch;
using DownloadProgress = std::function<bool(size_t current, size_t total)>;
using UploadProgress = std::function<bool(size_t current, size_t total)>;
#if __cplusplus >= 201703L
using any = std::any;
using bad_any_cast = std::bad_any_cast;
template <typename T> T any_cast(const any &a) { return std::any_cast<T>(a); }
template <typename T> T any_cast(any &a) { return std::any_cast<T>(a); }
template <typename T> T any_cast(any &&a) {
return std::any_cast<T>(std::move(a));
}
template <typename T> const T *any_cast(const any *a) noexcept {
return std::any_cast<T>(a);
}
template <typename T> T *any_cast(any *a) noexcept {
return std::any_cast<T>(a);
}
#else // C++11/14 implementation
class bad_any_cast : public std::bad_cast {
public:
const char *what() const noexcept override { return "bad any_cast"; }
};
namespace detail {
using any_type_id = const void *;
// Returns a unique per-type ID without RTTI.
// The static address is stable across TUs because function templates are
// implicitly inline and the ODR merges their statics into one.
template <typename T> any_type_id any_typeid() noexcept {
static const char id = 0;
return &id;
}
struct any_storage {
virtual ~any_storage() = default;
virtual std::unique_ptr<any_storage> clone() const = 0;
virtual any_type_id type_id() const noexcept = 0;
};
template <typename T> struct any_value final : any_storage {
T value;
template <typename U> explicit any_value(U &&v) : value(std::forward<U>(v)) {}
std::unique_ptr<any_storage> clone() const override {
return std::unique_ptr<any_storage>(new any_value<T>(value));
}
any_type_id type_id() const noexcept override { return any_typeid<T>(); }
};
} // namespace detail
class any {
std::unique_ptr<detail::any_storage> storage_;
public:
any() noexcept = default;
any(const any &o) : storage_(o.storage_ ? o.storage_->clone() : nullptr) {}
any(any &&) noexcept = default;
any &operator=(const any &o) {
storage_ = o.storage_ ? o.storage_->clone() : nullptr;
return *this;
}
any &operator=(any &&) noexcept = default;
template <
typename T, typename D = typename std::decay<T>::type,
typename std::enable_if<!std::is_same<D, any>::value, int>::type = 0>
any(T &&v) : storage_(new detail::any_value<D>(std::forward<T>(v))) {}
template <
typename T, typename D = typename std::decay<T>::type,
typename std::enable_if<!std::is_same<D, any>::value, int>::type = 0>
any &operator=(T &&v) {
storage_.reset(new detail::any_value<D>(std::forward<T>(v)));
return *this;
}
bool has_value() const noexcept { return storage_ != nullptr; }
void reset() noexcept { storage_.reset(); }
template <typename T> friend T *any_cast(any *a) noexcept;
template <typename T> friend const T *any_cast(const any *a) noexcept;
};
template <typename T> T *any_cast(any *a) noexcept {
if (!a || !a->storage_) { return nullptr; }
if (a->storage_->type_id() != detail::any_typeid<T>()) { return nullptr; }
return &static_cast<detail::any_value<T> *>(a->storage_.get())->value;
}
template <typename T> const T *any_cast(const any *a) noexcept {
if (!a || !a->storage_) { return nullptr; }
if (a->storage_->type_id() != detail::any_typeid<T>()) { return nullptr; }
return &static_cast<const detail::any_value<T> *>(a->storage_.get())->value;
}
template <typename T> T any_cast(const any &a) {
using U =
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
const U *p = any_cast<U>(&a);
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
if (!p) { throw bad_any_cast{}; }
#else
if (!p) { std::abort(); }
#endif
return static_cast<T>(*p);
}
template <typename T> T any_cast(any &a) {
using U =
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
U *p = any_cast<U>(&a);
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
if (!p) { throw bad_any_cast{}; }
#else
if (!p) { std::abort(); }
#endif
return static_cast<T>(*p);
}
template <typename T> T any_cast(any &&a) {
using U =
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
U *p = any_cast<U>(&a);
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
if (!p) { throw bad_any_cast{}; }
#else
if (!p) { std::abort(); }
#endif
return static_cast<T>(std::move(*p));
}
#endif // __cplusplus >= 201703L
struct Response;
using ResponseHandler = std::function<bool(const Response &response)>;
@ -805,6 +1010,60 @@ struct FormDataProvider {
};
using FormDataProviderItems = std::vector<FormDataProvider>;
inline FormDataProvider
make_file_provider(const std::string &name, const std::string &filepath,
const std::string &filename = std::string(),
const std::string &content_type = std::string()) {
FormDataProvider fdp;
fdp.name = name;
fdp.filename = filename.empty() ? filepath : filename;
fdp.content_type = content_type;
fdp.provider = [filepath](size_t offset, DataSink &sink) -> bool {
std::ifstream f(filepath, std::ios::binary);
if (!f) { return false; }
if (offset > 0) {
f.seekg(static_cast<std::streamoff>(offset));
if (!f.good()) {
sink.done();
return true;
}
}
char buf[8192];
f.read(buf, sizeof(buf));
auto n = static_cast<size_t>(f.gcount());
if (n > 0) { return sink.write(buf, n); }
sink.done(); // EOF
return true;
};
return fdp;
}
inline std::pair<size_t, ContentProvider>
make_file_body(const std::string &filepath) {
std::ifstream f(filepath, std::ios::binary | std::ios::ate);
if (!f) { return {0, ContentProvider{}}; }
auto size = static_cast<size_t>(f.tellg());
ContentProvider provider = [filepath](size_t offset, size_t length,
DataSink &sink) -> bool {
std::ifstream f(filepath, std::ios::binary);
if (!f) { return false; }
f.seekg(static_cast<std::streamoff>(offset));
if (!f.good()) { return false; }
char buf[8192];
while (length > 0) {
auto to_read = (std::min)(sizeof(buf), length);
f.read(buf, static_cast<std::streamsize>(to_read));
auto n = static_cast<size_t>(f.gcount());
if (n == 0) { break; }
if (!sink.write(buf, n)) { return false; }
length -= n;
}
return true;
};
return {size, std::move(provider)};
}
using ContentReceiverWithProgress = std::function<bool(
const char *data, size_t data_length, size_t offset, size_t total_length)>;
@ -1010,6 +1269,10 @@ struct Response {
std::string body;
std::string location; // Redirect location
// User-defined context — set by pre-routing/pre-request handlers and read
// by route handlers to pass arbitrary data (e.g. decoded auth tokens).
std::map<std::string, any> user_data;
bool has_header(const std::string &key) const;
std::string get_header_value(const std::string &key, const char *def = "",
size_t id = 0) const;
@ -1115,6 +1378,7 @@ public:
virtual bool is_readable() const = 0;
virtual bool wait_readable() const = 0;
virtual bool wait_writable() const = 0;
virtual bool is_peer_alive() const { return wait_writable(); }
virtual ssize_t read(char *ptr, size_t size) = 0;
virtual ssize_t write(const char *ptr, size_t size) = 0;
@ -1124,6 +1388,11 @@ public:
virtual time_t duration() const = 0;
virtual void set_read_timeout(time_t sec, time_t usec = 0) {
(void)sec;
(void)usec;
}
ssize_t write(const char *ptr);
ssize_t write(const std::string &s);
@ -1146,7 +1415,7 @@ public:
class ThreadPool final : public TaskQueue {
public:
explicit ThreadPool(size_t n, size_t mqr = 0);
explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0);
ThreadPool(const ThreadPool &) = delete;
~ThreadPool() override = default;
@ -1154,20 +1423,22 @@ public:
void shutdown() override;
private:
struct worker {
explicit worker(ThreadPool &pool);
void worker(bool is_dynamic);
void move_to_finished(std::thread::id id);
void cleanup_finished_threads();
void operator()();
ThreadPool &pool_;
};
friend struct worker;
std::vector<std::thread> threads_;
std::list<std::function<void()>> jobs_;
size_t base_thread_count_;
size_t max_thread_count_;
size_t max_queued_requests_;
size_t idle_thread_count_;
bool shutdown_;
size_t max_queued_requests_ = 0;
std::list<std::function<void()>> jobs_;
std::vector<std::thread> threads_; // base threads
std::list<std::thread> dynamic_threads_; // dynamic threads
std::vector<std::thread>
finished_threads_; // exited dynamic threads awaiting join
std::condition_variable cond_;
std::mutex mutex_;
@ -1294,6 +1565,11 @@ public:
using Expect100ContinueHandler =
std::function<int(const Request &, Response &)>;
using WebSocketHandler =
std::function<void(const Request &, ws::WebSocket &)>;
using SubProtocolSelector =
std::function<std::string(const std::vector<std::string> &protocols)>;
Server();
virtual ~Server();
@ -1311,6 +1587,10 @@ public:
Server &Delete(const std::string &pattern, HandlerWithContentReader handler);
Server &Options(const std::string &pattern, Handler handler);
Server &WebSocket(const std::string &pattern, WebSocketHandler handler);
Server &WebSocket(const std::string &pattern, WebSocketHandler handler,
SubProtocolSelector sub_protocol_selector);
bool set_base_dir(const std::string &dir,
const std::string &mount_point = std::string());
bool set_mount_point(const std::string &mount_point, const std::string &dir,
@ -1386,7 +1666,8 @@ protected:
int remote_port, const std::string &local_addr,
int local_port, bool close_connection,
bool &connection_closed,
const std::function<void(Request &)> &setup_request);
const std::function<void(Request &)> &setup_request,
bool *websocket_upgraded = nullptr);
std::atomic<socket_t> svr_sock_{INVALID_SOCKET};
@ -1488,6 +1769,14 @@ private:
HandlersForContentReader delete_handlers_for_content_reader_;
Handlers options_handlers_;
struct WebSocketHandlerEntry {
std::unique_ptr<detail::MatcherBase> matcher;
WebSocketHandler handler;
SubProtocolSelector sub_protocol_selector;
};
using WebSocketHandlers = std::vector<WebSocketHandlerEntry>;
WebSocketHandlers websocket_handlers_;
HandlerWithResponse error_handler_;
ExceptionHandler exception_handler_;
HandlerWithResponse pre_routing_handler_;
@ -2970,6 +3259,36 @@ struct MbedTlsContext {
} // namespace tls
#endif
#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT
namespace tls {
namespace impl {
// wolfSSL context wrapper (holds WOLFSSL_CTX and related state).
// This struct is accessible via tls::impl for use in SSL context
// setup callbacks (cast ctx_t to tls::impl::WolfSSLContext*).
struct WolfSSLContext {
WOLFSSL_CTX *ctx = nullptr;
bool is_server = false;
bool verify_client = false;
bool has_verify_callback = false;
std::string ca_pem_data_; // accumulated PEM for get_ca_names/get_ca_certs
WolfSSLContext();
~WolfSSLContext();
WolfSSLContext(const WolfSSLContext &) = delete;
WolfSSLContext &operator=(const WolfSSLContext &) = delete;
};
// CA store for wolfSSL: holds raw PEM bytes to allow reloading into any ctx
struct WolfSSLCAStore {
std::string pem_data;
};
} // namespace impl
} // namespace tls
#endif
#endif // CPPHTTPLIB_SSL_ENABLED
namespace stream {
@ -3335,6 +3654,143 @@ private:
} // namespace sse
namespace ws {
enum class Opcode : uint8_t {
Continuation = 0x0,
Text = 0x1,
Binary = 0x2,
Close = 0x8,
Ping = 0x9,
Pong = 0xA,
};
enum class CloseStatus : uint16_t {
Normal = 1000,
GoingAway = 1001,
ProtocolError = 1002,
UnsupportedData = 1003,
NoStatus = 1005,
Abnormal = 1006,
InvalidPayload = 1007,
PolicyViolation = 1008,
MessageTooBig = 1009,
MandatoryExtension = 1010,
InternalError = 1011,
};
enum ReadResult : int { Fail = 0, Text = 1, Binary = 2 };
class WebSocket {
public:
WebSocket(const WebSocket &) = delete;
WebSocket &operator=(const WebSocket &) = delete;
~WebSocket();
ReadResult read(std::string &msg);
bool send(const std::string &data);
bool send(const char *data, size_t len);
void close(CloseStatus status = CloseStatus::Normal,
const std::string &reason = "");
const Request &request() const;
bool is_open() const;
private:
friend class httplib::Server;
friend class WebSocketClient;
WebSocket(Stream &strm, const Request &req, bool is_server)
: strm_(strm), req_(req), is_server_(is_server) {
start_heartbeat();
}
WebSocket(std::unique_ptr<Stream> &&owned_strm, const Request &req,
bool is_server)
: strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req),
is_server_(is_server) {
start_heartbeat();
}
void start_heartbeat();
bool send_frame(Opcode op, const char *data, size_t len, bool fin = true);
Stream &strm_;
std::unique_ptr<Stream> owned_strm_;
Request req_;
bool is_server_;
std::atomic<bool> closed_{false};
std::mutex write_mutex_;
std::thread ping_thread_;
std::mutex ping_mutex_;
std::condition_variable ping_cv_;
};
class WebSocketClient {
public:
explicit WebSocketClient(const std::string &scheme_host_port_path,
const Headers &headers = {});
~WebSocketClient();
WebSocketClient(const WebSocketClient &) = delete;
WebSocketClient &operator=(const WebSocketClient &) = delete;
bool is_valid() const;
bool connect();
ReadResult read(std::string &msg);
bool send(const std::string &data);
bool send(const char *data, size_t len);
void close(CloseStatus status = CloseStatus::Normal,
const std::string &reason = "");
bool is_open() const;
const std::string &subprotocol() const;
void set_read_timeout(time_t sec, time_t usec = 0);
void set_write_timeout(time_t sec, time_t usec = 0);
#ifdef CPPHTTPLIB_SSL_ENABLED
void set_ca_cert_path(const std::string &path);
void set_ca_cert_store(tls::ca_store_t store);
void enable_server_certificate_verification(bool enabled);
#endif
private:
void shutdown_and_close();
bool create_stream(std::unique_ptr<Stream> &strm);
std::string host_;
int port_;
std::string path_;
Headers headers_;
std::string subprotocol_;
bool is_valid_ = false;
socket_t sock_ = INVALID_SOCKET;
std::unique_ptr<WebSocket> ws_;
time_t read_timeout_sec_ = CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND;
time_t read_timeout_usec_ = 0;
time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND;
time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
#ifdef CPPHTTPLIB_SSL_ENABLED
bool is_ssl_ = false;
tls::ctx_t tls_ctx_ = nullptr;
tls::session_t tls_session_ = nullptr;
std::string ca_cert_file_path_;
tls::ca_store_t ca_cert_store_ = nullptr;
bool server_certificate_verification_ = true;
#endif
};
namespace impl {
bool is_valid_utf8(const std::string &s);
bool read_websocket_frame(Stream &strm, Opcode &opcode, std::string &payload,
bool &fin, bool expect_masked, size_t max_len);
} // namespace impl
} // namespace ws
} // namespace httplib