Merge with master
This commit is contained in:
commit
9f9d06407e
|
|
@ -1,8 +1,8 @@
|
|||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG ROCM_VERSION=7.0
|
||||
ARG AMDGPU_VERSION=7.0
|
||||
ARG ROCM_VERSION=7.2
|
||||
ARG AMDGPU_VERSION=7.2
|
||||
|
||||
# Target the ROCm build image
|
||||
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
|
||||
|
|
@ -11,13 +11,12 @@ ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-co
|
|||
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878
|
||||
# This is mostly tied to rocBLAS supported archs.
|
||||
# gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102,not officialy supported
|
||||
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
|
||||
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
|
||||
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
|
||||
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html
|
||||
|
||||
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
|
||||
#ARG ROCM_DOCKER_ARCH='gfx1151'
|
||||
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
|
||||
|
||||
# Set ROCm architectures
|
||||
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
|
||||
|
|
|
|||
|
|
@ -516,6 +516,102 @@ jobs:
|
|||
path: llama-bin-win-sycl-x64.zip
|
||||
name: llama-bin-win-sycl-x64.zip
|
||||
|
||||
ubuntu-22-rocm:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- ROCM_VERSION: "7.2"
|
||||
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
|
||||
build: 'x64'
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-rocm-cmake-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt install -y build-essential git cmake wget
|
||||
|
||||
- name: Setup Legacy ROCm
|
||||
if: matrix.ROCM_VERSION == '7.2'
|
||||
id: legacy_env
|
||||
run: |
|
||||
sudo mkdir --parents --mode=0755 /etc/apt/keyrings
|
||||
wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | \
|
||||
gpg --dearmor | sudo tee /etc/apt/keyrings/rocm.gpg > /dev/null
|
||||
|
||||
sudo tee /etc/apt/sources.list.d/rocm.list << EOF
|
||||
deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${{ matrix.ROCM_VERSION }} jammy main
|
||||
EOF
|
||||
|
||||
sudo tee /etc/apt/preferences.d/rocm-pin-600 << EOF
|
||||
Package: *
|
||||
Pin: release o=repo.radeon.com
|
||||
Pin-Priority: 600
|
||||
EOF
|
||||
|
||||
sudo apt update
|
||||
sudo apt-get install -y libssl-dev rocm-hip-sdk
|
||||
|
||||
- name: Setup TheRock
|
||||
if: matrix.ROCM_VERSION != '7.2'
|
||||
id: therock_env
|
||||
run: |
|
||||
wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz
|
||||
mkdir install
|
||||
tar -xf *.tar.gz -C install
|
||||
export ROCM_PATH=$(pwd)/install
|
||||
echo ROCM_PATH=$ROCM_PATH >> $GITHUB_ENV
|
||||
echo PATH=$PATH:$ROCM_PATH/bin >> $GITHUB_ENV
|
||||
echo LD_LIBRARY_PATH=$ROCM_PATH/lib:$ROCM_PATH/llvm/lib:$ROCM_PATH/lib/rocprofiler-systems >> $GITHUB_ENV
|
||||
|
||||
- name: Build with native CMake HIP support
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGPU_TARGETS="${{ matrix.gpu_targets }}" \
|
||||
-DGGML_HIP=ON \
|
||||
-DHIP_PLATFORM=amd \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
|
||||
|
||||
windows-hip:
|
||||
runs-on: windows-2022
|
||||
|
||||
|
|
@ -784,6 +880,7 @@ jobs:
|
|||
- windows-cuda
|
||||
- windows-sycl
|
||||
- windows-hip
|
||||
- ubuntu-22-rocm
|
||||
- ubuntu-22-cpu
|
||||
- ubuntu-22-vulkan
|
||||
- macOS-arm64
|
||||
|
|
@ -868,6 +965,7 @@ jobs:
|
|||
**Linux:**
|
||||
- [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz)
|
||||
- [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz)
|
||||
- [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-7.2-x64.tar.gz)
|
||||
- [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz)
|
||||
|
||||
**Windows:**
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
|
||||
cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
|
||||
project("llama.cpp" C CXX)
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
|
|
|
|||
|
|
@ -803,7 +803,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons
|
|||
}
|
||||
|
||||
// remove potential partial suffix
|
||||
if (builder.pos() == builder.input().size()) {
|
||||
if (builder.pos() == builder.input().size() && builder.is_partial()) {
|
||||
if (unclosed_reasoning_content.empty()) {
|
||||
rstrip(content);
|
||||
trim_potential_partial_word(content);
|
||||
|
|
|
|||
|
|
@ -893,23 +893,6 @@ static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
|||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_call>";
|
||||
form.tool_start = "<function=";
|
||||
form.tool_sep = ">";
|
||||
form.key_start = "<parameter=";
|
||||
form.key_val_sep = ">";
|
||||
form.val_end = "</parameter>";
|
||||
form.tool_end = "</function>";
|
||||
form.scope_end = "</tool_call>";
|
||||
form.trim_raw_argval = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
|
|
@ -1590,9 +1573,6 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
|||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
common_chat_parse_kimi_k2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||
common_chat_parse_qwen3_coder_xml(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||
common_chat_parse_apriel_1_5(builder);
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -736,7 +736,6 @@ const char * common_chat_format_name(common_chat_format format) {
|
|||
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||
case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
|
||||
|
|
@ -1522,14 +1521,17 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
|
|||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
static common_chat_params common_chat_params_init_qwen3_coder(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
|
||||
|
||||
// Nemotron Nano 3 and Step-3.5-Flash use the Qwen3 Coder tool calling with thinking
|
||||
bool supports_reasoning = (tmpl.source().find("<think>") != std::string::npos);
|
||||
|
||||
// Handle thinking tags appropriately based on inputs.enable_thinking
|
||||
if (string_ends_with(data.prompt, "<think>\n")) {
|
||||
if (supports_reasoning && string_ends_with(data.prompt, "<think>\n")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
data.prompt += "</think>";
|
||||
} else {
|
||||
|
|
@ -1538,19 +1540,21 @@ static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_
|
|||
}
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
};
|
||||
|
||||
if (supports_reasoning) {
|
||||
data.preserved_tokens.insert(data.preserved_tokens.end(), {"<think>", "</think>"});
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
|
||||
auto parser = build_chat_peg_constructed_parser([&](auto & p) {
|
||||
auto reasoning = p.eps();
|
||||
if (inputs.enable_thinking && extract_reasoning) {
|
||||
if (supports_reasoning && inputs.enable_thinking && extract_reasoning) {
|
||||
auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
|
||||
if (data.thinking_forced_open) {
|
||||
reasoning = reasoning_content;
|
||||
|
|
@ -1888,38 +1892,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t
|
|||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<function=",
|
||||
"</function>",
|
||||
"<parameter=",
|
||||
"</parameter>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<tool_call>\n",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">\n",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">\n",
|
||||
/* form.val_end = */ "\n</parameter>\n",
|
||||
/* form.tool_end = */ "</function>\n",
|
||||
/* form.scope_end = */ "</tool_call>",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
|
@ -3147,13 +3119,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
src.find("<function=") != std::string::npos &&
|
||||
src.find("<parameter=") != std::string::npos) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
// Models with <think> support (Step-3.5-Flash, Nemotron 3 Nano) use the
|
||||
// Nemotron v3 PEG parser for streaming and schema-aware parameter parsing.
|
||||
// Qwen3-Coder has no <think> in its template.
|
||||
if (src.find("<think>") != std::string::npos) {
|
||||
return common_chat_params_init_nemotron_v3(tmpl, params);
|
||||
}
|
||||
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
|
||||
return common_chat_params_init_qwen3_coder(tmpl, params);
|
||||
}
|
||||
|
||||
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)
|
||||
|
|
|
|||
|
|
@ -128,7 +128,6 @@ enum common_chat_format {
|
|||
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
COMMON_CHAT_FORMAT_SOLAR_OPEN,
|
||||
|
|
|
|||
|
|
@ -1760,3 +1760,65 @@ float lr_opt::get_lr(float epoch) const {
|
|||
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
|
||||
return r;
|
||||
}
|
||||
|
||||
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
batch.pos = &pos;
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s: failed to replay last token\n", __func__);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool common_prompt_batch_decode(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & tokens,
|
||||
int & n_past,
|
||||
int n_batch,
|
||||
std::string_view state_path,
|
||||
bool save_state) {
|
||||
const int n_eval = tokens.size();
|
||||
if (n_eval == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (save_state && n_eval > 1) {
|
||||
const int n_tokens_before_last = n_eval - 1;
|
||||
|
||||
GGML_ASSERT(n_eval <= n_batch);
|
||||
|
||||
// Decode all but the last token so we can save the memory state before decoding the last token.
|
||||
// This is done so we can restore the session state later and replay the last token.
|
||||
// Memory implementations in recurrent/hybrid models don't support removing tokens from their
|
||||
// memory, so we can't just remove the last token from the memory and replay the last token which
|
||||
// is the reason for this logic.
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past += n_tokens_before_last;
|
||||
|
||||
llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
|
||||
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
|
||||
|
||||
llama_token last_token = tokens.back();
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
int32_t pos = n_past;
|
||||
batch.pos = &pos;
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval last token\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past++;
|
||||
} else {
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past += n_eval;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -804,6 +804,23 @@ void common_batch_add(
|
|||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
// decodes a single batch of tokens for a prompt and manages session tokens
|
||||
//
|
||||
// Note: We save state before the last token so that we can replay it to ensure
|
||||
// compatibility with all memory types. Recurrent/hybrid models cannot remove
|
||||
// tokens from memory, so this approach works across all model architectures.
|
||||
bool common_prompt_batch_decode(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & embd,
|
||||
int & n_past,
|
||||
int n_batch,
|
||||
std::string_view state_path,
|
||||
bool save_state);
|
||||
|
||||
// replays the last token after loading state to regenerate logits
|
||||
// used after loading session state to ensure the sampling context has valid logits
|
||||
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ value identifier::execute_impl(context & ctx) {
|
|||
auto builtins = global_builtins();
|
||||
if (!it->is_undefined()) {
|
||||
if (ctx.is_get_stats) {
|
||||
it->stats.used = true;
|
||||
value_t::stats_t::mark_used(it);
|
||||
}
|
||||
JJ_DEBUG("Identifier '%s' found, type = %s", val.c_str(), it->type().c_str());
|
||||
return it;
|
||||
|
|
@ -277,7 +277,7 @@ value binary_expression::execute_impl(context & ctx) {
|
|||
static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) {
|
||||
JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str());
|
||||
if (ctx.is_get_stats) {
|
||||
input->stats.used = true;
|
||||
value_t::stats_t::mark_used(input);
|
||||
input->stats.ops.insert(name);
|
||||
}
|
||||
auto builtins = input->get_builtins();
|
||||
|
|
@ -448,7 +448,7 @@ value for_statement::execute_impl(context & ctx) {
|
|||
|
||||
// mark the variable being iterated as used for stats
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
value_t::stats_t::mark_used(iterable_val);
|
||||
iterable_val->stats.ops.insert("array_access");
|
||||
}
|
||||
|
||||
|
|
@ -470,7 +470,7 @@ value for_statement::execute_impl(context & ctx) {
|
|||
items.push_back(std::move(tuple));
|
||||
}
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
value_t::stats_t::mark_used(iterable_val);
|
||||
iterable_val->stats.ops.insert("object_access");
|
||||
}
|
||||
} else {
|
||||
|
|
@ -480,7 +480,7 @@ value for_statement::execute_impl(context & ctx) {
|
|||
items.push_back(item);
|
||||
}
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
value_t::stats_t::mark_used(iterable_val);
|
||||
iterable_val->stats.ops.insert("array_access");
|
||||
}
|
||||
}
|
||||
|
|
@ -817,8 +817,9 @@ value member_expression::execute_impl(context & ctx) {
|
|||
}
|
||||
|
||||
if (ctx.is_get_stats && val && object && property) {
|
||||
val->stats.used = true;
|
||||
object->stats.used = true;
|
||||
value_t::stats_t::mark_used(val);
|
||||
value_t::stats_t::mark_used(object);
|
||||
value_t::stats_t::mark_used(property);
|
||||
if (is_val<value_int>(property)) {
|
||||
object->stats.ops.insert("array_access");
|
||||
} else if (is_val<value_string>(property)) {
|
||||
|
|
|
|||
|
|
@ -161,6 +161,11 @@ static value tojson(const func_args & args) {
|
|||
value val_separators = args.get_kwarg_or_pos("separators", 3);
|
||||
value val_sort = args.get_kwarg_or_pos("sort_keys", 4);
|
||||
int indent = -1;
|
||||
if (args.ctx.is_get_stats) {
|
||||
// mark as used (recursively) for stats
|
||||
auto val_input = args.get_pos(0);
|
||||
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
|
||||
}
|
||||
if (is_val<value_int>(val_indent)) {
|
||||
indent = static_cast<int>(val_indent->as_int());
|
||||
}
|
||||
|
|
@ -891,6 +896,11 @@ const func_builtins & value_array_t::get_builtins() const {
|
|||
}},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_array>();
|
||||
if (args.ctx.is_get_stats) {
|
||||
// mark as used (recursively) for stats
|
||||
auto val_input = args.get_pos(0);
|
||||
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
|
||||
}
|
||||
return mk_val<value_string>(args.get_pos(0)->as_string());
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
|
|
@ -1046,6 +1056,11 @@ const func_builtins & value_object_t::get_builtins() const {
|
|||
{"tojson", tojson},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
if (args.ctx.is_get_stats) {
|
||||
// mark as used (recursively) for stats
|
||||
auto val_input = args.get_pos(0);
|
||||
value_t::stats_t::mark_used(const_cast<value&>(val_input), true);
|
||||
}
|
||||
return mk_val<value_string>(args.get_pos(0)->as_string());
|
||||
}},
|
||||
{"length", [](const func_args & args) -> value {
|
||||
|
|
@ -1358,4 +1373,21 @@ std::string value_to_string_repr(const value & val) {
|
|||
}
|
||||
}
|
||||
|
||||
// stats utility
|
||||
void value_t::stats_t::mark_used(value & val, bool deep) {
|
||||
val->stats.used = true;
|
||||
if (deep) {
|
||||
if (is_val<value_array>(val)) {
|
||||
for (auto & item : val->val_arr) {
|
||||
mark_used(item, deep);
|
||||
}
|
||||
} else if (is_val<value_object>(val)) {
|
||||
for (auto & pair : val->val_obj) {
|
||||
mark_used(pair.first, deep);
|
||||
mark_used(pair.second, deep);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
|
|
@ -118,6 +118,8 @@ struct value_t {
|
|||
bool used = false;
|
||||
// ops can be builtin calls or operators: "array_access", "object_access"
|
||||
std::set<std::string> ops;
|
||||
// utility to recursively mark value and its children as used
|
||||
static void mark_used(value & val, bool deep = false);
|
||||
} stats;
|
||||
|
||||
value_t() = default;
|
||||
|
|
|
|||
|
|
@ -1274,6 +1274,9 @@ class TextModel(ModelBase):
|
|||
if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
|
||||
# ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
|
||||
res = "joyai-llm"
|
||||
if chkhsh == "e4d54df1ebc1f2b91acd986c5b51aa50837d5faf7c7398e73c1f9e9ee5d19869":
|
||||
# ref: https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601
|
||||
res = "kanana2"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
|
|
|
|||
|
|
@ -152,6 +152,7 @@ models = [
|
|||
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
|
||||
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
|
||||
{"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
|
|
|
|||
|
|
@ -77,7 +77,10 @@ causal-verify-embeddings: causal-run-original-embeddings causal-run-converted-em
|
|||
@./scripts/causal/compare-embeddings-logits.sh
|
||||
|
||||
causal-inspect-original-model:
|
||||
@./scripts/utils/inspect-org-model.py
|
||||
@./scripts/utils/inspect-org-model.py --list-all -s
|
||||
|
||||
causal-list-original-model-tensors:
|
||||
@./scripts/utils/inspect-org-model.py --list-all-short -s
|
||||
|
||||
causal-inspect-converted-model:
|
||||
@./scripts/utils/inspect-converted-model.sh
|
||||
|
|
@ -153,7 +156,7 @@ embedding-verify-logits-st: embedding-run-original-model-st embedding-run-conver
|
|||
|
||||
embedding-inspect-original-model:
|
||||
$(call validate_embedding_model_path,embedding-inspect-original-model)
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} --list-all -s
|
||||
|
||||
embedding-inspect-converted-model:
|
||||
@CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/utils/inspect-converted-model.sh ${CONVERTED_EMBEDDING_MODEL}
|
||||
|
|
|
|||
|
|
@ -1,67 +1,290 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from safetensors import safe_open
|
||||
from collections import defaultdict
|
||||
|
||||
parser = argparse.ArgumentParser(description='Process model with specified path')
|
||||
parser.add_argument('--model-path', '-m', help='Path to the model')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = os.environ.get('MODEL_PATH', args.model_path)
|
||||
if model_path is None:
|
||||
parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
|
||||
MODEL_SAFETENSORS_FILE = "model.safetensors"
|
||||
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
|
||||
|
||||
# Check if there's an index file (multi-file model)
|
||||
index_path = os.path.join(model_path, "model.safetensors.index.json")
|
||||
single_file_path = os.path.join(model_path, "model.safetensors")
|
||||
DTYPE_SIZES = {
|
||||
"F64": 8, "I64": 8, "U64": 8,
|
||||
"F32": 4, "I32": 4, "U32": 4,
|
||||
"F16": 2, "BF16": 2, "I16": 2, "U16": 2,
|
||||
"I8": 1, "U8": 1, "BOOL": 1,
|
||||
"F8_E4M3": 1, "F8_E5M2": 1,
|
||||
}
|
||||
|
||||
if os.path.exists(index_path):
|
||||
# Multi-file model
|
||||
print("Multi-file model detected")
|
||||
SIZE_UNITS = ['B', 'KB', 'MB', 'GB', 'TB']
|
||||
|
||||
with open(index_path, 'r') as f:
|
||||
index_data = json.load(f)
|
||||
|
||||
# Get the weight map (tensor_name -> file_name)
|
||||
weight_map = index_data.get("weight_map", {})
|
||||
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
|
||||
index_file = model_path / MODEL_SAFETENSORS_INDEX
|
||||
|
||||
# Group tensors by file for efficient processing
|
||||
file_tensors = defaultdict(list)
|
||||
for tensor_name, file_name in weight_map.items():
|
||||
file_tensors[file_name].append(tensor_name)
|
||||
if index_file.exists():
|
||||
with open(index_file, 'r') as f:
|
||||
index = json.load(f)
|
||||
return index.get("weight_map", {})
|
||||
|
||||
print("Tensors in model:")
|
||||
return None
|
||||
|
||||
# Process each shard file
|
||||
for file_name, tensor_names in file_tensors.items():
|
||||
file_path = os.path.join(model_path, file_name)
|
||||
print(f"\n--- From {file_name} ---")
|
||||
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
for tensor_name in sorted(tensor_names):
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
|
||||
def get_all_tensor_names(model_path: Path) -> list[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
elif os.path.exists(single_file_path):
|
||||
# Single file model (original behavior)
|
||||
print("Single-file model detected")
|
||||
if weight_map is not None:
|
||||
return list(weight_map.keys())
|
||||
|
||||
with safe_open(single_file_path, framework="pt") as f:
|
||||
keys = f.keys()
|
||||
print("Tensors in model:")
|
||||
for key in sorted(keys):
|
||||
tensor = f.get_tensor(key)
|
||||
print(f"- {key} : shape = {tensor.shape}, dtype = {tensor.dtype}")
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
with safe_open(single_file, framework="pt", device="cpu") as f:
|
||||
return list(f.keys())
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {model_path}")
|
||||
print("Available files:")
|
||||
if os.path.exists(model_path):
|
||||
for item in sorted(os.listdir(model_path)):
|
||||
print(f" {item}")
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return weight_map.get(tensor_name)
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
return single_file.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def read_safetensors_header(file_path: Path) -> dict:
|
||||
with open(file_path, 'rb') as f:
|
||||
header_size = struct.unpack('<Q', f.read(8))[0]
|
||||
return json.loads(f.read(header_size))
|
||||
|
||||
|
||||
def get_tensor_size_bytes(tensor_meta: dict) -> int:
|
||||
offsets = tensor_meta.get("data_offsets")
|
||||
if offsets and len(offsets) == 2:
|
||||
return offsets[1] - offsets[0]
|
||||
n_elements = 1
|
||||
for d in tensor_meta.get("shape", []):
|
||||
n_elements *= d
|
||||
return n_elements * DTYPE_SIZES.get(tensor_meta.get("dtype", "F32"), 4)
|
||||
|
||||
|
||||
def format_size(size_bytes: int) -> str:
|
||||
val = float(size_bytes)
|
||||
for unit in SIZE_UNITS[:-1]:
|
||||
if val < 1024.0:
|
||||
return f"{val:.2f} {unit}"
|
||||
val /= 1024.0
|
||||
return f"{val:.2f} {SIZE_UNITS[-1]}"
|
||||
|
||||
|
||||
def get_all_tensor_metadata(model_path: Path) -> dict[str, dict]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
file_to_tensors: dict[str, list[str]] = {}
|
||||
for tensor_name, file_name in weight_map.items():
|
||||
file_to_tensors.setdefault(file_name, []).append(tensor_name)
|
||||
|
||||
all_metadata: dict[str, dict] = {}
|
||||
for file_name, tensor_names in file_to_tensors.items():
|
||||
try:
|
||||
header = read_safetensors_header(model_path / file_name)
|
||||
for tensor_name in tensor_names:
|
||||
if tensor_name in header:
|
||||
all_metadata[tensor_name] = header[tensor_name]
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not read header from {file_name}: {e}", file=sys.stderr)
|
||||
return all_metadata
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
header = read_safetensors_header(single_file)
|
||||
return {k: v for k, v in header.items() if k != "__metadata__"}
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def normalize_tensor_name(tensor_name: str) -> str:
|
||||
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
|
||||
normalized = re.sub(r'\.\d+$', '.#', normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def list_all_tensors(
|
||||
model_path: Path,
|
||||
short: bool = False,
|
||||
show_sizes: bool = False,
|
||||
):
|
||||
tensor_names = get_all_tensor_names(model_path)
|
||||
|
||||
metadata: Optional[dict[str, dict]] = None
|
||||
if show_sizes:
|
||||
metadata = get_all_tensor_metadata(model_path)
|
||||
|
||||
total_bytes = 0
|
||||
|
||||
if short:
|
||||
seen: dict[str, str] = {}
|
||||
for tensor_name in sorted(tensor_names):
|
||||
normalized = normalize_tensor_name(tensor_name)
|
||||
if normalized not in seen:
|
||||
seen[normalized] = tensor_name
|
||||
display_pairs = list(sorted(seen.items()))
|
||||
name_width = max((len(n) for n, _ in display_pairs), default=0)
|
||||
for normalized, first_name in display_pairs:
|
||||
if metadata and first_name in metadata:
|
||||
m = metadata[first_name]
|
||||
size = get_tensor_size_bytes(m)
|
||||
total_bytes += size
|
||||
print(f"{normalized:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}")
|
||||
else:
|
||||
print(normalized)
|
||||
else:
|
||||
print(f" Directory {model_path} does not exist")
|
||||
exit(1)
|
||||
name_width = max((len(n) for n in tensor_names), default=0)
|
||||
for tensor_name in sorted(tensor_names):
|
||||
if metadata and tensor_name in metadata:
|
||||
m = metadata[tensor_name]
|
||||
size = get_tensor_size_bytes(m)
|
||||
total_bytes += size
|
||||
print(f"{tensor_name:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}")
|
||||
else:
|
||||
print(tensor_name)
|
||||
|
||||
if show_sizes:
|
||||
print(f"\nTotal: {format_size(total_bytes)}")
|
||||
|
||||
|
||||
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
|
||||
tensor_file = find_tensor_file(model_path, tensor_name)
|
||||
|
||||
if tensor_file is None:
|
||||
print(f"Error: Could not find tensor '{tensor_name}' in model index")
|
||||
print(f"Model path: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
file_path = model_path / tensor_file
|
||||
|
||||
try:
|
||||
header = read_safetensors_header(file_path)
|
||||
tensor_meta = header.get(tensor_name, {})
|
||||
dtype_str = tensor_meta.get("dtype")
|
||||
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
if tensor_name in f.keys():
|
||||
tensor_slice = f.get_slice(tensor_name)
|
||||
shape = tensor_slice.get_shape()
|
||||
print(f"Tensor: {tensor_name}")
|
||||
print(f"File: {tensor_file}")
|
||||
print(f"Shape: {shape}")
|
||||
if dtype_str:
|
||||
print(f"Dtype: {dtype_str}")
|
||||
if tensor_meta:
|
||||
print(f"Size: {format_size(get_tensor_size_bytes(tensor_meta))}")
|
||||
if num_values is not None:
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
if not dtype_str:
|
||||
print(f"Dtype: {tensor.dtype}")
|
||||
flat = tensor.flatten()
|
||||
n = min(num_values, flat.numel())
|
||||
print(f"Values: {flat[:n].tolist()}")
|
||||
else:
|
||||
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
|
||||
sys.exit(1)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{file_path}' was not found.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Print tensor information from a safetensors model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"tensor_name",
|
||||
nargs="?",
|
||||
help="Name of the tensor to inspect"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--model-path",
|
||||
type=Path,
|
||||
help="Path to the model directory (default: MODEL_PATH environment variable)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list-all-short",
|
||||
action="store_true",
|
||||
help="List unique tensor patterns (layer numbers replaced with #)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-la", "--list-all",
|
||||
action="store_true",
|
||||
help="List all tensor names with actual layer numbers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--num-values",
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s", "--sizes",
|
||||
action="store_true",
|
||||
help="Show dtype, shape, and size for each tensor when listing"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model_path
|
||||
if model_path is None:
|
||||
model_path_str = os.environ.get("MODEL_PATH")
|
||||
if model_path_str is None:
|
||||
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
|
||||
sys.exit(1)
|
||||
model_path = Path(model_path_str)
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model path does not exist: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not model_path.is_dir():
|
||||
print(f"Error: Model path is not a directory: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.list_all_short or args.list_all:
|
||||
list_all_tensors(model_path, short=args.list_all_short, show_sizes=args.sizes)
|
||||
else:
|
||||
if args.tensor_name is None:
|
||||
print("Error: tensor_name is required when not using --list-all-short or --list-all")
|
||||
sys.exit(1)
|
||||
print_tensor_info(model_path, args.tensor_name, args.num_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -1,174 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
MODEL_SAFETENSORS_FILE = "model.safetensors"
|
||||
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
|
||||
|
||||
|
||||
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
|
||||
index_file = model_path / MODEL_SAFETENSORS_INDEX
|
||||
|
||||
if index_file.exists():
|
||||
with open(index_file, 'r') as f:
|
||||
index = json.load(f)
|
||||
return index.get("weight_map", {})
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_all_tensor_names(model_path: Path) -> list[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return list(weight_map.keys())
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
with safe_open(single_file, framework="pt", device="cpu") as f:
|
||||
return list(f.keys())
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return weight_map.get(tensor_name)
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
return single_file.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def normalize_tensor_name(tensor_name: str) -> str:
|
||||
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
|
||||
normalized = re.sub(r'\.\d+$', '.#', normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def list_all_tensors(model_path: Path, unique: bool = False):
|
||||
tensor_names = get_all_tensor_names(model_path)
|
||||
|
||||
if unique:
|
||||
seen = set()
|
||||
for tensor_name in sorted(tensor_names):
|
||||
normalized = normalize_tensor_name(tensor_name)
|
||||
if normalized not in seen:
|
||||
seen.add(normalized)
|
||||
print(normalized)
|
||||
else:
|
||||
for tensor_name in sorted(tensor_names):
|
||||
print(tensor_name)
|
||||
|
||||
|
||||
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
|
||||
tensor_file = find_tensor_file(model_path, tensor_name)
|
||||
|
||||
if tensor_file is None:
|
||||
print(f"Error: Could not find tensor '{tensor_name}' in model index")
|
||||
print(f"Model path: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
file_path = model_path / tensor_file
|
||||
|
||||
try:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
if tensor_name in f.keys():
|
||||
tensor_slice = f.get_slice(tensor_name)
|
||||
shape = tensor_slice.get_shape()
|
||||
print(f"Tensor: {tensor_name}")
|
||||
print(f"File: {tensor_file}")
|
||||
print(f"Shape: {shape}")
|
||||
if num_values is not None:
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
print(f"Dtype: {tensor.dtype}")
|
||||
flat = tensor.flatten()
|
||||
n = min(num_values, flat.numel())
|
||||
print(f"Values: {flat[:n].tolist()}")
|
||||
else:
|
||||
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
|
||||
sys.exit(1)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{file_path}' was not found.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Print tensor information from a safetensors model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"tensor_name",
|
||||
nargs="?", # optional (if --list is used for example)
|
||||
help="Name of the tensor to inspect"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--model-path",
|
||||
type=Path,
|
||||
help="Path to the model directory (default: MODEL_PATH environment variable)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list",
|
||||
action="store_true",
|
||||
help="List unique tensor patterns in the model (layer numbers replaced with #)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--num-values",
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model_path
|
||||
if model_path is None:
|
||||
model_path_str = os.environ.get("MODEL_PATH")
|
||||
if model_path_str is None:
|
||||
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
|
||||
sys.exit(1)
|
||||
model_path = Path(model_path_str)
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model path does not exist: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not model_path.is_dir():
|
||||
print(f"Error: Model path is not a directory: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.list:
|
||||
list_all_tensors(model_path, unique=True)
|
||||
else:
|
||||
if args.tensor_name is None:
|
||||
print("Error: tensor_name is required when not using --list")
|
||||
sys.exit(1)
|
||||
print_tensor_info(model_path, args.tensor_name, args.num_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -5,12 +5,15 @@
|
|||
#include <vector>
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.prompt = "The quick brown fox";
|
||||
params.sampling.seed = 1234;
|
||||
|
||||
const std::string_view state_file = "dump_state.bin";
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -53,35 +56,16 @@ int main(int argc, char ** argv) {
|
|||
// tokenize prompt
|
||||
auto tokens = common_tokenize(ctx, params.prompt, true);
|
||||
|
||||
// prepare the batch
|
||||
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
common_batch_add(batch, tokens[i], i, {0}, false);
|
||||
const bool save_state = true;
|
||||
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
|
||||
return 1;
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true; // generate next token
|
||||
|
||||
// evaluate prompt
|
||||
llama_decode(ctx, batch);
|
||||
n_past += batch.n_tokens;
|
||||
|
||||
// save state (rng, logits, embedding and kv_cache) to file
|
||||
{
|
||||
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
||||
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
|
||||
|
||||
FILE *fp_write = fopen("dump_state.bin", "wb");
|
||||
fwrite(state_mem.data(), 1, written, fp_write);
|
||||
fclose(fp_write);
|
||||
|
||||
fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
|
||||
}
|
||||
|
||||
// save state (last tokens)
|
||||
const auto n_past_saved = n_past;
|
||||
|
||||
// first run
|
||||
printf("\nfirst run: %s", params.prompt.c_str());
|
||||
|
||||
llama_batch batch = llama_batch_init(1, 0, 1);
|
||||
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx, next_token);
|
||||
|
|
@ -111,27 +95,23 @@ int main(int argc, char ** argv) {
|
|||
|
||||
printf("\nsecond run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
{
|
||||
std::vector<uint8_t> state_mem;
|
||||
// load state from file
|
||||
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
|
||||
size_t n_token_count_out = 0;
|
||||
|
||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||
fseek(fp_read, 0, SEEK_END);
|
||||
state_mem.resize(ftell(fp_read));
|
||||
fseek(fp_read, 0, SEEK_SET);
|
||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||
fclose(fp_read);
|
||||
|
||||
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
|
||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
|
||||
if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_past_saved;
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
|
|
@ -160,7 +140,9 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// make new context
|
||||
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
|
||||
auto params_ctx3 = common_context_params_to_llama(params);
|
||||
params_ctx3.n_seq_max = 2;
|
||||
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
|
||||
|
||||
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
||||
|
||||
|
|
@ -169,26 +151,21 @@ int main(int argc, char ** argv) {
|
|||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
{
|
||||
std::vector<uint8_t> state_mem;
|
||||
n_token_count_out = 0;
|
||||
|
||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||
fseek(fp_read, 0, SEEK_END);
|
||||
state_mem.resize(ftell(fp_read));
|
||||
fseek(fp_read, 0, SEEK_SET);
|
||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||
fclose(fp_read);
|
||||
|
||||
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
|
||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
|
||||
if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_past_saved;
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@
|
|||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
|
|
@ -55,9 +56,10 @@
|
|||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -77,6 +79,7 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
|
|
@ -86,6 +89,7 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
|
|
@ -110,6 +114,7 @@
|
|||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
|
|
@ -123,6 +128,7 @@
|
|||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
|
|
@ -148,6 +154,7 @@
|
|||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
|
|
@ -161,6 +168,7 @@
|
|||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
|
|
@ -187,6 +195,7 @@
|
|||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
|
|
@ -199,6 +208,7 @@
|
|||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
|
|
@ -230,6 +240,7 @@
|
|||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
|
|
@ -243,6 +254,7 @@
|
|||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
|
|
@ -276,6 +288,7 @@
|
|||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
|
|
@ -289,6 +302,7 @@
|
|||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
|
|
|
|||
|
|
@ -785,6 +785,165 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
|||
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[col_groups];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
|
||||
float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
|
||||
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d);
|
||||
float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d);
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
|
||||
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
|
||||
int32x4_t acc_lo[col_groups];
|
||||
int32x4_t acc_hi[col_groups];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
|
||||
uint8x16_t qh[col_groups][8];
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
|
||||
}
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_mins[2];
|
||||
int16x8_t q5sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
int8x16_t q8_qs[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
|
||||
}
|
||||
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
uint8x16_t q5_cols[8];
|
||||
uint8x16_t hbit_lo[8];
|
||||
uint8x16_t hbit_hi[8];
|
||||
int8x16_t q5_lo[8];
|
||||
int8x16_t q5_hi[8];
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
|
||||
hbit_lo[i] = vandq_u8(qh[c][i], mone);
|
||||
hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
|
||||
qh[c][i] = vshrq_n_u8(qh[c][i], 2);
|
||||
q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
|
||||
q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
|
||||
}
|
||||
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
|
||||
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
|
||||
}
|
||||
|
||||
// Scales
|
||||
// row c0123 blk0 and blk1
|
||||
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
|
||||
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
|
||||
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
|
||||
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
|
||||
// row c4567 blk0 and blk1
|
||||
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
|
||||
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
|
||||
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
|
||||
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
|
||||
|
||||
// Bias Correction
|
||||
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
} // for sb
|
||||
|
||||
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
|
||||
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
|
|
@ -3205,6 +3364,235 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
|
||||
constexpr int col_groups = ncols_interleaved / 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 8 accumulators: 2 row pairs, 4 col pairs
|
||||
float32x4_t acc_f32[acc_size];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// d5 0 1 2 3, 4 5 6 7
|
||||
float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
|
||||
float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
|
||||
// d8 0 1 2 3
|
||||
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
||||
// mins
|
||||
float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
|
||||
float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
|
||||
|
||||
// Precomputation of scales and mins
|
||||
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
||||
float32x4_t sbd_min_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_min_4567[q8_k_blocklen];
|
||||
|
||||
sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
|
||||
sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
|
||||
sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
|
||||
sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
|
||||
|
||||
sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
|
||||
sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
|
||||
sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
|
||||
sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
|
||||
|
||||
sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
|
||||
sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
|
||||
sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
|
||||
sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
|
||||
|
||||
sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
|
||||
sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
|
||||
sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
|
||||
sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
|
||||
|
||||
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
|
||||
const int16x8_t bsums[q8_k_blocklen] = {
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[QK_K / 64][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
|
||||
int32x4_t bias_acc[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
uint8x16_t qh[col_groups][8];
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
|
||||
}
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Int accumulators for qs vecdot (4 row * 2 col quartets)
|
||||
int32x4_t acc_lo[acc_size];
|
||||
int32x4_t acc_hi[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_scales[2];
|
||||
int16x8_t q5sb_mins[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
|
||||
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
|
||||
|
||||
// 0..3 & 32..35
|
||||
const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
|
||||
const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
|
||||
|
||||
// NOTE: This is the only difference with q4_K
|
||||
const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
|
||||
const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
|
||||
qh[0][k] = vshrq_n_u8(qh[0][k], 2);
|
||||
const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
|
||||
const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
|
||||
qh[1][k] = vshrq_n_u8(qh[1][k], 2);
|
||||
// From here, same as q4_K
|
||||
|
||||
const int8x16_t q5_0123_lo =
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
|
||||
const int8x16_t q5_0123_hi =
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
|
||||
|
||||
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
|
||||
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
|
||||
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
|
||||
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
|
||||
|
||||
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
|
||||
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
|
||||
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
|
||||
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
|
||||
|
||||
const int8x16_t q5_4567_lo =
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
|
||||
const int8x16_t q5_4567_hi =
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
|
||||
|
||||
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
|
||||
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
|
||||
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
|
||||
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
|
||||
|
||||
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
|
||||
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
|
||||
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
|
||||
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
|
||||
}
|
||||
|
||||
// Scale and bias application
|
||||
// acc is stored interleaved to match output layout
|
||||
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
|
||||
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
// Bias correction
|
||||
// row c0123 blk0 and blk1
|
||||
const float32x4_t sumf_0123 =
|
||||
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
|
||||
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
|
||||
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
|
||||
|
||||
// row c4567 blk0 and blk1
|
||||
const float32x4_t sumf_4567 =
|
||||
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
|
||||
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
|
||||
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
|
||||
|
||||
// Bias
|
||||
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
|
||||
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
|
||||
|
||||
// row c0123 blk0 and blk1
|
||||
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
|
||||
// row c4567 blk0 and blk1
|
||||
bias_acc[2 * row + 1] =
|
||||
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * row + 1] =
|
||||
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
|
||||
acc_f32[2 * row + 1] =
|
||||
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
|
||||
}
|
||||
} // for b
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
|
|
|
|||
|
|
@ -450,6 +450,208 @@ static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
|
|||
}
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[ncols_interleaved];
|
||||
float sum_minf[ncols_interleaved];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
constexpr int scale_stride = 32;
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
|
||||
|
||||
const int qh_shift = (k / (32 / blocklen)) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * blocklen + i) % 32;
|
||||
const int qh_chunk = qh_idx / blocklen;
|
||||
const int qh_pos = qh_idx % blocklen;
|
||||
const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][ncols_interleaved];
|
||||
float sum_minf[4][ncols_interleaved];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
constexpr int scale_stride = 32;
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
|
||||
|
||||
const int qh_shift = (k / (32 / blocklen)) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * blocklen + i) % 32;
|
||||
const int qh_chunk = qh_idx / blocklen;
|
||||
const int qh_pos = qh_idx % blocklen;
|
||||
const int b_qh_offset =
|
||||
qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k / (32 / blocklen)) * 256 +
|
||||
(k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -803,98 +1005,12 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -1494,107 +1610,12 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
@ -2029,18 +2050,16 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
|
|||
|
||||
const int end = QK_K * 4 / blck_size_interleave;
|
||||
|
||||
// Interleave Q5_K quants by taking 8 bytes at a time
|
||||
// Interleave Q5_K quants by taking blck_size_interleave bytes at a time
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
|
||||
}
|
||||
|
||||
// Repeat for low bits 8 bytes at a time as well, since
|
||||
// Repeat for high bits with the same chunk size, since
|
||||
// the high bits are interleaved in Q5_K and the index is
|
||||
// qh_idx = (qs_idx % 32);
|
||||
// qh_val = qh[qh_idx] >> (qs_idx / 32);
|
||||
|
|
@ -2049,9 +2068,7 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
|
|||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave);
|
||||
}
|
||||
|
||||
// The below logic is copied over from Q4_K
|
||||
|
|
@ -2249,7 +2266,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
|||
const void * GGML_RESTRICT data,
|
||||
size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
|
||||
|
|
@ -2523,6 +2540,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
|
|||
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
|
@ -2591,6 +2612,10 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
|||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
|
@ -2654,6 +2679,10 @@ template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
|||
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
|
@ -3068,6 +3097,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for Q5_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 4, 8, GGML_TYPE_Q8_K> q5_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q6_K
|
||||
|
|
@ -3130,6 +3160,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||
return &q5_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q5_K_8x4_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q6_K) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
@ -122,6 +123,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
@ -143,6 +145,7 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
@ -154,6 +157,7 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
|
|||
|
|
@ -1149,8 +1149,7 @@ struct ggml_cuda_graph {
|
|||
size_t num_nodes = 0;
|
||||
std::vector<cudaGraphNode_t> nodes;
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool disable_due_to_too_many_updates = false;
|
||||
int number_consecutive_updates = 0;
|
||||
bool warmup_complete = false;
|
||||
std::vector<ggml_cuda_graph_node_properties> props;
|
||||
|
||||
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
|
||||
|
|
@ -1159,21 +1158,9 @@ struct ggml_cuda_graph {
|
|||
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
|
||||
std::vector<ggml_cuda_graph_node_properties> extra;
|
||||
|
||||
void record_update(bool use_graph, bool update_required) {
|
||||
if (use_graph && update_required) {
|
||||
number_consecutive_updates++;
|
||||
} else {
|
||||
number_consecutive_updates = 0;
|
||||
}
|
||||
if (number_consecutive_updates >= 4) {
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
|
||||
disable_due_to_too_many_updates = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_enabled() const {
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
|
||||
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2979,10 +2979,6 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
|||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (graph->instance == nullptr) {
|
||||
res = true;
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes) {
|
||||
res = true;
|
||||
|
|
@ -3931,14 +3927,35 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|||
#ifdef USE_CUDA_GRAPH
|
||||
graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->is_enabled()) {
|
||||
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
|
||||
const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
|
||||
if (graph_compatible) {
|
||||
const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
|
||||
graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
if (!graph->warmup_complete) {
|
||||
// Warmup: need at least 2 calls with no property change on the 2nd call
|
||||
if (!properties_changed) {
|
||||
graph->warmup_complete = true;
|
||||
GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
|
||||
use_cuda_graph = true;
|
||||
cuda_graph_update_required = true;
|
||||
}
|
||||
// else: properties changed or first call - execute directly (use_cuda_graph stays false)
|
||||
} else {
|
||||
// Post-warmup: normal CUDA graph operation
|
||||
if (properties_changed) {
|
||||
// Properties changed - reset warmup, execute directly until stable again
|
||||
graph->warmup_complete = false;
|
||||
GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
|
||||
} else {
|
||||
use_cuda_graph = true;
|
||||
cuda_graph_update_required = graph->instance == nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
|
|
|
|||
|
|
@ -1749,23 +1749,6 @@ static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backe
|
|||
return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
|
||||
}
|
||||
|
||||
static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
|
||||
if (x->ne[0] != y->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[1] != y->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[2] != y->ne[2]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[3] != y->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
|
|
@ -1797,43 +1780,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
|
|||
return opt_experimental;
|
||||
}
|
||||
|
||||
static bool hex_supported_src0_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src2_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type2(ggml_type t) {
|
||||
return t == GGML_TYPE_F16;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type3(ggml_type t) {
|
||||
return t == GGML_TYPE_I32;
|
||||
}
|
||||
|
||||
static bool hex_supported_dst_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
|
||||
// TODO: support broadcast for ne[2 and 3]
|
||||
if (x->ne[0] != y->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[2] != y->ne[2]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[3] != y->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
|
@ -1919,19 +1865,19 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
|
|||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_can_repeat(src1, src0)) {
|
||||
if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1943,16 +1889,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
|
|||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1968,13 +1914,13 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
|||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -1990,10 +1936,10 @@ static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session *
|
|||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2011,10 +1957,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
|
|||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -2023,10 +1969,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
|
|||
}
|
||||
|
||||
if (src1) {
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, src1)) {
|
||||
if (!ggml_are_same_shape(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
|
|
@ -2047,15 +1993,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
|
|||
return false; // FIXME: add support for sinks
|
||||
}
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1) {
|
||||
if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
if (src0->ne[0] != src1->ne[0]) {
|
||||
|
|
@ -2162,17 +2108,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
|
|||
const struct ggml_tensor * src2 = op->src[2];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false; // FIXME: add support for GGML_TYPE_F16 for src0
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type3(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_I32) {
|
||||
return false;
|
||||
}
|
||||
if (src2) {
|
||||
if (!hex_supported_src2_type(src2->type)) {
|
||||
if (src2->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
int n_dims = op_params[1];
|
||||
|
|
|
|||
|
|
@ -69,27 +69,45 @@
|
|||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
struct htp_act_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
// Precomputed values
|
||||
const uint8_t * data_src0;
|
||||
const uint8_t * data_src1;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t src1_row_size;
|
||||
size_t dst_row_size;
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t src1_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
|
||||
size_t src0_spad_half_size;
|
||||
size_t src1_spad_half_size;
|
||||
size_t dst_spad_half_size;
|
||||
|
||||
uint32_t block;
|
||||
uint32_t src0_nrows;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
int nc;
|
||||
};
|
||||
|
||||
static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
|
||||
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
|
|
@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
|||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
|
@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
|||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
|
@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
|||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
|
||||
"%zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
const float alpha = ((const float *) (op_params))[2];
|
||||
const float limit = ((const float *) (op_params))[3];
|
||||
const float alpha = ((const float *) (actx->octx->op_params))[2];
|
||||
const float limit = ((const float *) (actx->octx->op_params))[3];
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
|
|
@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
|||
}
|
||||
|
||||
|
||||
static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size = actx->src0_row_size;
|
||||
const size_t dst_row_size = actx->dst_row_size;
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
|
@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
|||
return;
|
||||
}
|
||||
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
uint8_t * data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * data_src0 = actx->data_src0;
|
||||
uint8_t * data_dst = actx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
// nc/ne0 matches.
|
||||
const int ne0_val = actx->nc; // == dst->ne[0]
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// In gelu = x*sigmoid(x*1.702)
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
|
@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
|||
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// gelu = x * sigmoid(1.702 * x) // current implementation
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
|
|
@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
|||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size = actx->src0_row_size;
|
||||
const size_t dst_row_size = actx->dst_row_size;
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
|
@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
|||
return;
|
||||
}
|
||||
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
uint8_t * data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * data_src0 = actx->data_src0;
|
||||
uint8_t * data_dst = actx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
const int ne0_val = actx->nc; // == dst->ne[0]
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = actx->block;
|
||||
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
|
@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
|||
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// silu = x * sigmoid(x)
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
|
|
@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
|||
static const float GELU_COEF_A = 0.044715f;
|
||||
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
|
@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
|||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
|
@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
|||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
|
@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
|||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_UNARY_SILU:
|
||||
act_op_func = unary_silu_f32;
|
||||
act_op_func = (worker_callback_t)unary_silu_f32_per_thread;
|
||||
op_type = "silu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
act_op_func = glu_swiglu_f32;
|
||||
act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;
|
||||
op_type = "swiglu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
act_op_func = glu_swiglu_oai_f32;
|
||||
act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;
|
||||
op_type = "swiglu-oai-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_GELU:
|
||||
act_op_func = unary_gelu_f32;
|
||||
act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;
|
||||
op_type = "gelu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
act_op_func = glu_geglu_f32;
|
||||
act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;
|
||||
op_type = "geglu-f32";
|
||||
break;
|
||||
default:
|
||||
|
|
@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
|||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
|
||||
if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
return err;
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
// Prepare context
|
||||
struct htp_act_context actx;
|
||||
actx.octx = octx;
|
||||
|
||||
actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
actx.src0_row_size = src0_row_size;
|
||||
actx.src1_row_size = src1_row_size;
|
||||
actx.dst_row_size = dst_row_size;
|
||||
|
||||
actx.src0_row_size_aligned = src0_row_size_aligned;
|
||||
actx.src1_row_size_aligned = src1_row_size_aligned;
|
||||
actx.dst_row_size_aligned = dst_row_size_aligned;
|
||||
|
||||
actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
|
||||
actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
|
||||
actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2;
|
||||
|
||||
actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
|
||||
actx.src0_nrows = src0_nrows;
|
||||
|
||||
actx.nc = dst->ne[0];
|
||||
|
||||
// Pointers and GLU logic
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * data_src1 = (const uint8_t *) src1->data;
|
||||
|
||||
if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
|
||||
const int32_t swapped = octx->op_params[1];
|
||||
data_src1 = data_src0;
|
||||
actx.src1_row_size = actx.src0_row_size;
|
||||
|
||||
size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
|
||||
if (swapped) {
|
||||
data_src0 += nc_in_bytes;
|
||||
} else {
|
||||
data_src1 += nc_in_bytes;
|
||||
}
|
||||
}
|
||||
|
||||
actx.data_src0 = data_src0;
|
||||
actx.data_src1 = data_src1;
|
||||
actx.data_dst = (uint8_t *) dst->data;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int op_activations(struct htp_ops_context * octx) {
|
||||
|
|
|
|||
|
|
@ -15,6 +15,13 @@
|
|||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
struct get_rows_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
struct fastdiv_values get_rows_div_ne10;
|
||||
struct fastdiv_values get_rows_div_ne10_ne11;
|
||||
};
|
||||
|
||||
#define get_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
|
|
@ -39,20 +46,22 @@
|
|||
\
|
||||
const uint32_t nr = ne10 * ne11 * ne12;
|
||||
|
||||
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct get_rows_context * grctx = (struct get_rows_context *)data;
|
||||
struct htp_ops_context * octx = grctx->octx;
|
||||
get_rows_preamble;
|
||||
|
||||
// parallelize by src1 elements (which correspond to dst rows)
|
||||
const uint32_t dr = octx->src1_nrows_per_thread;
|
||||
const uint32_t dr = grctx->src1_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
|
||||
const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
|
||||
const uint32_t rem = i - i12 * ne11 * ne10;
|
||||
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
|
||||
const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
|
||||
const uint32_t i10 = rem - i11 * ne10;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
|
@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
|||
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_get_rows(struct htp_ops_context * octx) {
|
||||
|
|
@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
struct get_rows_context grctx;
|
||||
grctx.octx = octx;
|
||||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q,
|
|||
dmlink(q->tail, desc);
|
||||
q->tail = desc;
|
||||
|
||||
// FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
|
||||
// FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
|
||||
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||
return true;
|
||||
}
|
||||
|
|
@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
|
|||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
|
||||
// FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
|
||||
// FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
|
||||
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
|
||||
return dptr;
|
||||
}
|
||||
|
||||
static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {
|
||||
dma_ptr dptr = { NULL };
|
||||
|
||||
if (q->push_idx == q->pop_idx) {
|
||||
return dptr;
|
||||
}
|
||||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
|
||||
// FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
|
||||
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
|
||||
return dptr;
|
||||
}
|
||||
|
||||
static inline bool dma_queue_empty(dma_queue * q) {
|
||||
return q->push_idx == q->pop_idx;
|
||||
}
|
||||
|
||||
static inline uint32_t dma_queue_depth(dma_queue * q) {
|
||||
return (q->push_idx - q->pop_idx) & q->idx_mask;
|
||||
}
|
||||
|
||||
static inline uint32_t dma_queue_capacity(dma_queue * q) {
|
||||
return q->capacity;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -44,32 +44,6 @@ struct htp_ops_context {
|
|||
uint32_t src0_nrows_per_thread;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
|
||||
struct fastdiv_values src0_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src0_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src0_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values src1_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src1_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src1_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values src3_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src3_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src3_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
|
||||
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
|
||||
|
||||
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
|
||||
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -49,62 +49,6 @@ struct htp_matmul_context {
|
|||
struct fastdiv_values mm_div_r3;
|
||||
};
|
||||
|
||||
// vdelta control to replicate first 4x fp32 values across lanes
|
||||
static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
||||
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
||||
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
|
||||
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
|
||||
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
|
||||
0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
|
||||
};
|
||||
|
||||
// vdelta control to replicate and interleave first 8x fp32 values across lanes
|
||||
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
|
||||
0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
||||
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
|
||||
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
|
||||
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
|
||||
0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp32 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
||||
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
|
||||
0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
|
||||
0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
|
||||
0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
|
||||
0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp16 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
|
||||
0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
|
||||
0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
|
||||
0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
|
||||
0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
|
||||
0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
|
||||
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp16 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
};
|
||||
|
||||
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
|
||||
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
|
||||
|
|
@ -2067,10 +2011,10 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
|
|||
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
|
||||
|
||||
// Convert to QF32
|
||||
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
|
||||
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
|
||||
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
|
||||
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
|
||||
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
|
||||
|
||||
// Combine and convert to fp16
|
||||
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
|
||||
|
|
@ -2080,11 +2024,6 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
|
|||
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
|
||||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
|
||||
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
||||
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
||||
|
||||
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
|
||||
|
|
@ -2130,13 +2069,8 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric
|
|||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Compute max and scale
|
||||
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
||||
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
|
||||
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
||||
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
||||
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
|
||||
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
|
||||
|
||||
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
|
|
@ -2179,11 +2113,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric
|
|||
|
||||
// Compute max and scale
|
||||
HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
||||
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
|
||||
vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
|
||||
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
|
||||
|
||||
HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-fastdiv.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
|
|
@ -21,6 +22,9 @@
|
|||
#define HTP_ROPE_TYPE_NORMAL 0
|
||||
#define HTP_ROPE_TYPE_NEOX 2
|
||||
|
||||
#define HTP_ROPE_SPAD_NROWS 16
|
||||
#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
|
||||
|
||||
#define htp_rope_preamble \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
|
|
@ -42,7 +46,7 @@
|
|||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct rope_th_ctx {
|
||||
struct htp_rope_context {
|
||||
int32_t n_dims;
|
||||
int32_t mode;
|
||||
int32_t n_ctx_orig;
|
||||
|
|
@ -57,7 +61,19 @@ struct rope_th_ctx {
|
|||
float theta_scale;
|
||||
float corr_dims[2];
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
size_t spad_stride;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
size_t theta_cache_offset;
|
||||
uint32_t src0_nrows;
|
||||
|
||||
uint64_t t_start;
|
||||
};
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
|
|
@ -117,64 +133,23 @@ static void rope_corr_dims(int n_dims,
|
|||
dims[1] = MIN(n_dims - 1, end);
|
||||
}
|
||||
|
||||
static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
|
||||
memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
|
||||
static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
|
||||
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
||||
const int32_t * op_params = &octx->op_params[0];
|
||||
uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2
|
||||
|
||||
rope_ctx->n_dims = ((const int32_t *) op_params)[1];
|
||||
rope_ctx->mode = ((const int32_t *) op_params)[2];
|
||||
rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
|
||||
uint32_t he = ne / 2; // half_dims offset in elements
|
||||
uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
|
||||
|
||||
memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
|
||||
memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
|
||||
memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
|
||||
memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
|
||||
memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
|
||||
memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
|
||||
memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i = 0; i < nvec; i += 2) {
|
||||
HVX_Vector v0 = vsrc[i/2+0];
|
||||
HVX_Vector v1 = vsrc[i/2+hv];
|
||||
|
||||
rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
|
||||
|
||||
rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
|
||||
rope_ctx->beta_slow, rope_ctx->corr_dims);
|
||||
|
||||
rope_ctx->octx = octx;
|
||||
FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
|
||||
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
const float * restrict theta_cache) {
|
||||
// for (int i = 0; i < num_elems; i += 2) {
|
||||
//const float cos_theta = theta_cache[i + 0];
|
||||
//const float sin_theta = theta_cache[i + 1];
|
||||
|
||||
//const float x0 = src[0];
|
||||
//const float x1 = src[num_elems/2];
|
||||
|
||||
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||
//dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
//src += 1;
|
||||
//dst += 1;
|
||||
// }
|
||||
|
||||
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||
|
||||
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||
int half_size = (sizeof(float) * (num_elems / 2));
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
|
||||
|
||||
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||
HVX_Vector v2 = vtheta[i+0];
|
||||
HVX_Vector v3 = vtheta[i+1];
|
||||
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
|
||||
|
|
@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
|||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
|
||||
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
|
||||
vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4);
|
||||
vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
|
||||
}
|
||||
|
||||
src0_curr += VLEN;
|
||||
theta_curr += 2 * VLEN;
|
||||
dst_curr += VLEN;
|
||||
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
|
||||
const float cos_theta = theta_cache[i+0];
|
||||
const float sin_theta = theta_cache[i+1];
|
||||
float x0 = src0[i/2];
|
||||
float x1 = src0[i/2 + he];
|
||||
dst[i/2] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
const float * restrict theta_cache) {
|
||||
// for (int i = 0; i < num_elems; i += 2) {
|
||||
//const float cos_theta = theta_cache[i + 0];
|
||||
//const float sin_theta = theta_cache[i + 1];
|
||||
static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
|
||||
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
||||
//const float x0 = src[0];
|
||||
//const float x1 = src[1];
|
||||
uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
|
||||
|
||||
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||
//dst[1] = x0*sin_theta + x1*cos_theta;
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i = 0; i < nvec; i+=2) {
|
||||
HVX_Vector v0 = vsrc[i+0];
|
||||
HVX_Vector v1 = vsrc[i+1];
|
||||
|
||||
//src += 2;
|
||||
//dst += 2;
|
||||
// }
|
||||
|
||||
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||
|
||||
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||
HVX_Vector v2 = vtheta[i+0];
|
||||
HVX_Vector v3 = vtheta[i+1];
|
||||
|
||||
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
|
|
@ -239,116 +203,65 @@ static void hvx_calc_rope_f32(const float * restrict src0,
|
|||
|
||||
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
|
||||
vdst[i+0] = Q6_V_lo_W(vstore);
|
||||
vdst[i+1] = Q6_V_hi_W(vstore);
|
||||
}
|
||||
|
||||
src0_curr += 2 * VLEN;
|
||||
theta_curr += 2 * VLEN;
|
||||
dst_curr += 2 * VLEN;
|
||||
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
|
||||
const float cos_theta = theta_cache[i+0];
|
||||
const float sin_theta = theta_cache[i+1];
|
||||
float x0 = src0[i+0];
|
||||
float x1 = src0[i+1];
|
||||
dst[i+0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i+1] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||
const uint32_t ir0,
|
||||
const uint32_t ir1,
|
||||
int nth,
|
||||
int ith,
|
||||
const int opt_path) {
|
||||
struct htp_ops_context * octx = rope_ctx->octx;
|
||||
static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
|
||||
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
|
||||
#pragma unroll(4)
|
||||
for (uint32_t i = 0; i < nr; i++) {
|
||||
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
|
||||
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
|
||||
|
||||
hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);
|
||||
|
||||
// fill the remain channels with data from src tensor
|
||||
if (rctx->n_dims < ne0) {
|
||||
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
|
||||
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
|
||||
#pragma unroll(4)
|
||||
for (uint32_t i = 0; i < nr; i++) {
|
||||
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
|
||||
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
|
||||
|
||||
hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);
|
||||
|
||||
// fill the remain channels with data from src tensor
|
||||
if (rctx->n_dims < ne0) {
|
||||
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_rope_context * rctx = (struct htp_rope_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const int32_t mode = rope_ctx->mode;
|
||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
const uint32_t i1_end = MIN(ir1, ne1);
|
||||
const int32_t half_dims = rope_ctx->n_dims / 2;
|
||||
const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
|
||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
const int32_t p = pos[i2];
|
||||
|
||||
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
|
||||
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
|
||||
|
||||
for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
|
||||
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
|
||||
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
|
||||
|
||||
const float * src_loc = src;
|
||||
float * dst_data_loc = dst_data;
|
||||
|
||||
if (1 == opt_path) {
|
||||
if (is_neox) {
|
||||
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
} else {
|
||||
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
}
|
||||
|
||||
src_loc += rope_ctx->n_dims;
|
||||
dst_data_loc += rope_ctx->n_dims;
|
||||
} else {
|
||||
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
|
||||
const float cos_theta = wp0[i0 + 0];
|
||||
const float sin_theta = wp0[i0 + 1];
|
||||
|
||||
if (is_neox) {
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[half_dims];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 1;
|
||||
dst_data_loc += 1;
|
||||
} else {
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[1];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 2;
|
||||
dst_data_loc += 2;
|
||||
}
|
||||
}
|
||||
|
||||
src_loc += (is_neox ? half_dims : 0);
|
||||
dst_data_loc += (is_neox ? half_dims : 0);
|
||||
}
|
||||
|
||||
// TODO: use simd to speed up the remaining elements copy
|
||||
memcpy(dst_data_loc, src_loc, remain_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
|
||||
struct htp_ops_context * octx = rope_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const uint32_t src0_nrows = rctx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
|
@ -358,32 +271,114 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int
|
|||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
uint64_t tt = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
const int32_t mode = rctx->mode;
|
||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||
|
||||
// VTCM setup
|
||||
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
float * theta_cache = (float *) (src0_spad_base);
|
||||
src0_spad_base = src0_spad_base + rctx->theta_cache_offset;
|
||||
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
|
||||
|
||||
uint32_t ir = 0;
|
||||
uint32_t prev_i2 = (uint32_t) -1;
|
||||
|
||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
|
||||
if (ir < src0_start_row) { ir++; i1++; continue; }
|
||||
if (ir >= src0_end_row) goto done;
|
||||
|
||||
// Rows in this block
|
||||
const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);
|
||||
|
||||
// Depth before prefetch
|
||||
uint32_t dma_depth = dma_queue_depth(dma_queue);
|
||||
|
||||
// FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
|
||||
// Prefetch loop
|
||||
for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {
|
||||
pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);
|
||||
|
||||
uint32_t pi1 = i1 + pr;
|
||||
uint32_t pir = ir + pr;
|
||||
|
||||
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);
|
||||
|
||||
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
|
||||
uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
|
||||
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
|
||||
|
||||
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
|
||||
}
|
||||
|
||||
// Update theta cache
|
||||
if (i2 != prev_i2) {
|
||||
prev_i2 = i2;
|
||||
|
||||
const int32_t p = pos[i2];
|
||||
rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
|
||||
|
||||
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
}
|
||||
|
||||
// Skip DMA transactions from prev block (if any)
|
||||
// No need to wait for these since the DMA is setup for in-order processing
|
||||
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
|
||||
|
||||
// Compute loop
|
||||
for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {
|
||||
// Number of rows to compute
|
||||
cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);
|
||||
|
||||
uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;
|
||||
uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
// FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
|
||||
if (is_neox) {
|
||||
rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
|
||||
} else {
|
||||
rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
|
||||
}
|
||||
|
||||
uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);
|
||||
|
||||
// Prefetch more rows (if any)
|
||||
if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {
|
||||
uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);
|
||||
uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;
|
||||
uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;
|
||||
|
||||
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
|
||||
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
|
||||
|
||||
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
|
||||
done:
|
||||
dma_queue_flush(dma_queue);
|
||||
tt = HAP_perf_get_qtimer_count() - tt;
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
|
||||
|
||||
rope_job_f32_per_thread(rope_ctx, n, i);
|
||||
FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));
|
||||
}
|
||||
|
||||
static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
|
|
@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
|||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
struct rope_th_ctx rope_ctx;
|
||||
const char * op_type = "rope-f32";
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_ROPE:
|
||||
op_func = rope_job_dispatcher_f32;
|
||||
op_type = "rope-f32";
|
||||
|
||||
init_rope_ctx(&rope_ctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
|
|
@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
|||
const uint32_t n_threads = octx->n_threads;
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src0_row_size;
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
// Aligned row sizes for VTCM
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
// Calculate spad sizes per thread
|
||||
size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
|
||||
size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;
|
||||
size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;
|
||||
|
||||
if (src2->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
|
||||
"dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
} else {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
}
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
// Check if we fit in VTCM
|
||||
size_t total_vtcm_needed = spad_per_thread * n_threads;
|
||||
if (octx->ctx->vtcm_size < total_vtcm_needed) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
// Assign sizes
|
||||
octx->src0_spad.size_per_thread = src0_spad_per_thread;
|
||||
octx->dst_spad.size_per_thread = dst_spad_per_thread;
|
||||
octx->src0_spad.size = n_threads * src0_spad_per_thread;
|
||||
octx->dst_spad.size = n_threads * dst_spad_per_thread;
|
||||
octx->src1_spad.size = 0;
|
||||
|
||||
// Assign pointers
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
// Fill context
|
||||
struct htp_rope_context rctx;
|
||||
memset(&rctx, 0, sizeof(struct htp_rope_context));
|
||||
|
||||
rctx.t_start = HAP_perf_get_qtimer_count();
|
||||
|
||||
rctx.octx = octx;
|
||||
|
||||
const int32_t * op_params = &octx->op_params[0];
|
||||
rctx.n_dims = ((const int32_t *) op_params)[1];
|
||||
rctx.mode = ((const int32_t *) op_params)[2];
|
||||
rctx.n_ctx_orig = ((const int32_t *) op_params)[4];
|
||||
|
||||
memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float));
|
||||
memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float));
|
||||
memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float));
|
||||
memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float));
|
||||
memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float));
|
||||
memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float));
|
||||
memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4);
|
||||
|
||||
rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);
|
||||
|
||||
rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);
|
||||
|
||||
rctx.src0_row_size = src0_row_size;
|
||||
rctx.dst_row_size = dst_row_size;
|
||||
rctx.src0_row_size_aligned = src0_row_size_aligned;
|
||||
rctx.dst_row_size_aligned = dst_row_size_aligned;
|
||||
rctx.theta_cache_offset = theta_cache_size_aligned;
|
||||
|
||||
uint32_t ne0 = dst->ne[0];
|
||||
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
rctx.src0_nrows = src0_nrows;
|
||||
|
||||
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
|
||||
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
|
|
|||
|
|
@ -43,11 +43,21 @@
|
|||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
struct htp_set_rows_context {
|
||||
struct htp_ops_context * octx;
|
||||
struct fastdiv_values div_ne12;
|
||||
struct fastdiv_values div_ne11;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
};
|
||||
|
||||
static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
|
||||
struct htp_ops_context * octx = srctx->octx;
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
|
|
@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
|||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
|
@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
|
||||
struct htp_ops_context * octx = srctx->octx;
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
|
|
@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
|
|||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
|
|
@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_set_rows(struct htp_ops_context * octx) {
|
||||
|
|
@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) {
|
|||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
|
||||
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
|
||||
struct htp_set_rows_context srctx;
|
||||
srctx.octx = octx;
|
||||
srctx.div_ne12 = init_fastdiv_values(ne12);
|
||||
srctx.div_ne11 = init_fastdiv_values(ne11);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
switch(octx->dst.type) {
|
||||
case HTP_TYPE_F32:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
|
||||
break;
|
||||
case HTP_TYPE_F16:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
|
||||
break;
|
||||
default:
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-fastdiv.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
|
|
@ -48,7 +49,7 @@
|
|||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct softmax_th_ctx {
|
||||
struct htp_softmax_context {
|
||||
bool use_f16;
|
||||
bool use_src1;
|
||||
uint32_t n_head;
|
||||
|
|
@ -59,28 +60,48 @@ struct softmax_th_ctx {
|
|||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
struct fastdiv_values fastdiv_ne01;
|
||||
struct fastdiv_values fastdiv_ne02;
|
||||
struct fastdiv_values fastdiv_ne12; // For mask broadcasting
|
||||
struct fastdiv_values fastdiv_ne13; // For mask broadcasting
|
||||
size_t spad_stride;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
};
|
||||
|
||||
static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
|
||||
static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
|
||||
memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
|
||||
memset(smctx, 0, sizeof(struct htp_softmax_context));
|
||||
|
||||
memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
|
||||
softmax_ctx->n_head = src0->ne[2];
|
||||
softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
|
||||
smctx->n_head = src0->ne[2];
|
||||
smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));
|
||||
|
||||
softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
|
||||
softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
|
||||
smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
|
||||
smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
|
||||
|
||||
softmax_ctx->use_src1 = (src1->ne[0] != 0);
|
||||
softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
smctx->use_src1 = (src1->ne[0] != 0);
|
||||
smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
|
||||
softmax_ctx->octx = octx;
|
||||
smctx->octx = octx;
|
||||
|
||||
// Initialize fastdiv values
|
||||
const uint32_t ne01 = src0->ne[1];
|
||||
const uint32_t ne02 = src0->ne[2];
|
||||
|
||||
if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
|
||||
if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
|
||||
|
||||
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
|
||||
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
|
||||
|
||||
if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
|
||||
if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
|
||||
}
|
||||
|
||||
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
|
||||
|
|
@ -139,8 +160,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
|||
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
|
||||
}
|
||||
|
||||
HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
|
||||
max_vec = hvx_vec_repl4(v);
|
||||
max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
|
|
@ -154,8 +174,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
|||
v_pad[i] = v3;
|
||||
}
|
||||
|
||||
v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
|
||||
sum_vec = hvx_vec_repl4(v);
|
||||
sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes
|
||||
|
||||
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
|
||||
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
|
||||
|
|
@ -183,83 +202,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
|
|||
return sum;
|
||||
}
|
||||
|
||||
static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
|
||||
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_softmax_preamble3;
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
float * wp2 = (float *) dst_spad_data;
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const uint32_t i11 = i01;
|
||||
const uint32_t i12 = i02 % ne12;
|
||||
const uint32_t i13 = i03 % ne13;
|
||||
|
||||
// ALiBi
|
||||
const uint32_t h = i02; // head
|
||||
|
||||
const float slope = (softmax_ctx->max_bias > 0.0f) ?
|
||||
h < softmax_ctx->n_head_log2 ?
|
||||
powf(softmax_ctx->m0, h + 1) :
|
||||
powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (softmax_ctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (softmax_ctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
|
||||
if (mp_f32) {
|
||||
if (softmax_ctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
|
||||
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||
static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
|
||||
struct htp_ops_context * octx = smctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
|
|
@ -268,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
|
|||
htp_softmax_preamble3;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
|
@ -291,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
|
|||
opt_path = 1;
|
||||
}
|
||||
|
||||
softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
float * wp2 = (float *) dst_spad_data;
|
||||
|
||||
uint32_t prev_i2 = (uint32_t)-1;
|
||||
float slope = 1.0f;
|
||||
|
||||
for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {
|
||||
uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);
|
||||
uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);
|
||||
uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);
|
||||
uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);
|
||||
|
||||
// Map to original logic indices
|
||||
// i01 = i1
|
||||
// i02 = i2
|
||||
// i03 = i3
|
||||
|
||||
const uint32_t i11 = i1;
|
||||
// const uint32_t i12 = i2 % ne12;
|
||||
// const uint32_t i13 = i3 % ne13;
|
||||
|
||||
uint32_t i12, i13;
|
||||
if (ne12 == ne02) {
|
||||
i12 = i2;
|
||||
} else {
|
||||
i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);
|
||||
}
|
||||
|
||||
if (ne13 == ne03) {
|
||||
i13 = i3;
|
||||
} else {
|
||||
i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);
|
||||
}
|
||||
|
||||
// ALiBi
|
||||
if (i2 != prev_i2) {
|
||||
const uint32_t h = i2; // head
|
||||
|
||||
slope = (smctx->max_bias > 0.0f) ?
|
||||
h < smctx->n_head_log2 ?
|
||||
powf(smctx->m0, h + 1) :
|
||||
powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
prev_i2 = i2;
|
||||
}
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (smctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (smctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
|
||||
if (mp_f32) {
|
||||
if (smctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
|
||||
struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
|
||||
softmax_job_f32_per_thread(p_softmax_ctx, n, i);
|
||||
}
|
||||
|
||||
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
|
|
@ -312,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
|||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
struct softmax_th_ctx softmax_ctx;
|
||||
struct htp_softmax_context smctx;
|
||||
const char * op_type = "softmax-f32";
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_SOFTMAX:
|
||||
op_func = softmax_job_dispatcher_f32;
|
||||
op_type = "softmax-f32";
|
||||
|
||||
init_softmax_ctx(&softmax_ctx, octx);
|
||||
init_softmax_ctx(&smctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
|
|
@ -342,6 +365,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
|||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
// Use stride for calculating offset
|
||||
smctx.spad_stride = hex_round_up(src0_row_size, 128);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
|
|
@ -371,8 +397,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
|||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@
|
|||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
|
||||
#define sum_rows_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0;\
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
|
|
@ -42,53 +41,54 @@
|
|||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
|
||||
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
sum_rows_preamble;
|
||||
struct sum_rows_context {
|
||||
const uint8_t * src_data;
|
||||
uint8_t * dst_data;
|
||||
uint32_t ne00;
|
||||
size_t src_stride;
|
||||
size_t dst_stride;
|
||||
uint32_t rows_per_thread;
|
||||
uint32_t total_rows;
|
||||
bool opt_path;
|
||||
};
|
||||
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
const struct sum_rows_context * smctx = (const struct sum_rows_context *) data;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t rows_per_thread = smctx->rows_per_thread;
|
||||
const uint32_t total_rows = smctx->total_rows;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
const uint32_t start_row = rows_per_thread * ith;
|
||||
const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return HTP_STATUS_OK;
|
||||
if (start_row >= end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
const size_t src_stride = smctx->src_stride;
|
||||
const size_t dst_stride = smctx->dst_stride;
|
||||
const uint32_t ne00 = smctx->ne00;
|
||||
const bool opt_path = smctx->opt_path;
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride));
|
||||
float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride));
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
// Calculate actual number of rows for this thread
|
||||
const uint32_t n_rows = end_row - start_row;
|
||||
|
||||
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * ne00);
|
||||
for (uint32_t ir = 0; ir < n_rows; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float)));
|
||||
|
||||
if (ir + 1 < src0_nrows_per_thread) {
|
||||
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
|
||||
if (ir + 1 < n_rows) {
|
||||
hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
if (opt_path) {
|
||||
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
|
||||
} else {
|
||||
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
|
||||
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_sum_rows(struct htp_ops_context * octx) {
|
||||
|
|
@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) {
|
|||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
|
||||
bool opt_path = false;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = true;
|
||||
}
|
||||
|
||||
struct sum_rows_context smctx = {
|
||||
.src_data = (const uint8_t *) src0->data,
|
||||
.dst_data = (uint8_t *) dst->data,
|
||||
.ne00 = ne00,
|
||||
.src_stride = nb01,
|
||||
.dst_stride = nb1,
|
||||
.rows_per_thread = rows_per_thread,
|
||||
.total_rows = src0_nrows,
|
||||
.opt_path = opt_path,
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,28 @@
|
|||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_unary_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
// Precomputed values
|
||||
const uint8_t * data_src0;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
|
||||
size_t src0_spad_half_size;
|
||||
size_t dst_spad_half_size;
|
||||
|
||||
uint32_t block;
|
||||
uint32_t src0_nrows;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
uint32_t nc;
|
||||
};
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t ne01 = src->ne[1]; \
|
||||
|
|
@ -57,8 +79,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
||||
sum_v = hvx_vec_repl4(reduced_sum);
|
||||
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
||||
|
|
@ -75,128 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|||
}
|
||||
}
|
||||
|
||||
static void scale_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void scale_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
float scale = 0.f;
|
||||
float bias = 0.f;
|
||||
memcpy(&scale, &op_params[0], sizeof(float));
|
||||
memcpy(&bias, &op_params[1], sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void rms_norm_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
float epsilon = 0.f;
|
||||
memcpy(&epsilon, op_params, sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
} else {
|
||||
float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
|
||||
|
||||
const float mean = sum / row_elems;
|
||||
const float scale = 1.0f / sqrtf(mean + epsilon);
|
||||
|
||||
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
|
||||
}
|
||||
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
static void sqr_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void sqr_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static void sqrt_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void sqrt_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad,
|
||||
int htp_op,
|
||||
int32_t * op_params,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
||||
struct htp_ops_context * octx = uctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_unary_preamble;
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
int htp_op = octx->op;
|
||||
int32_t * op_params = octx->op_params;
|
||||
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const size_t src0_row_size = uctx->src0_row_size;
|
||||
const size_t dst_row_size = uctx->dst_row_size;
|
||||
|
||||
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = uctx->src0_nrows;
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
|
|
@ -208,79 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
|||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
const uint8_t * restrict data_src = uctx->data_src0;
|
||||
uint8_t * restrict data_dst = uctx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
||||
|
||||
size_t src0_spad_half_size = uctx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = uctx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = uctx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
switch (htp_op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
default:
|
||||
break;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
// Process block in VTCM
|
||||
switch (htp_op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
|
||||
dst_row_size, dst_row_size_aligned, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
|
||||
FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
|
||||
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
|
||||
unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
|
||||
octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t unary_op_func;
|
||||
const char * op_type = NULL;
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "rmsnorm-f32";
|
||||
op_type = "rmsnorm-f32";
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "scale-f32";
|
||||
op_type = "scale-f32";
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqr-f32";
|
||||
op_type = "sqr-f32";
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqrt-f32";
|
||||
op_type = "sqrt-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
|
|
@ -294,32 +307,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
|||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
// Double buffering requires 2x size per buffer
|
||||
|
||||
size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
|
||||
size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (vtcm_row_per_thread == 0) {
|
||||
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size_per_row * n_threads);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
|
||||
octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2;
|
||||
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
struct htp_unary_context uctx = {
|
||||
.octx = octx,
|
||||
.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
|
||||
.src0_nrows = src0_nrows,
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
|
||||
.data_src0 = (const uint8_t *)src0->data,
|
||||
.data_dst = (uint8_t *)dst->data,
|
||||
|
||||
.src0_row_size = src0_row_size,
|
||||
.dst_row_size = dst_row_size,
|
||||
|
||||
.src0_row_size_aligned = src0_row_size_aligned,
|
||||
.dst_row_size_aligned = dst_row_size_aligned,
|
||||
|
||||
.src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
|
||||
.dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
|
||||
|
||||
.block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
|
||||
.nc = src0->ne[0],
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
|
|
|||
|
|
@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
|
|||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -fa on \
|
||||
-ngl 99 --device $device $cli_opts $@ \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
-ngl 99 --device $device $cli_opts $@ \
|
||||
"
|
||||
|
|
|
|||
|
|
@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
|
|||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -fa on \
|
||||
-ngl 99 -no-cnv --device $device $cli_opts $@ \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
-ngl 99 -no-cnv --device $device $cli_opts $@ \
|
||||
"
|
||||
|
|
|
|||
|
|
@ -58,11 +58,11 @@ adb $adbserial $adbhost shell " \
|
|||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
|
||||
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--mmproj $basedir/../gguf/$mmproj \
|
||||
--image $basedir/../gguf/$image \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
|
||||
-ngl 99 --device $device -v $cli_opts $@ \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
|
||||
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--mmproj $basedir/../gguf/$mmproj \
|
||||
--image $basedir/../gguf/$image \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
-ngl 99 --device $device -v $cli_opts $@ \
|
||||
"
|
||||
|
|
|
|||
|
|
@ -49,5 +49,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib"
|
|||
& "$basedir\bin\llama-completion.exe" `
|
||||
--no-mmap -no-cnv -m $basedir\..\..\gguf\$model `
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
|
||||
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on `
|
||||
--ctx-size 8192 --ubatch-size 128 -fa on `
|
||||
-ngl 99 --device $device $cli_opts
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import os
|
|||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "d4180e923f846b44a3d30acd938438d6e64fc9f6"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.34.0"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
|
|
|||
|
|
@ -2440,64 +2440,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
|||
// TODO: add more model-specific info which should prevent loading the session file if not identical
|
||||
}
|
||||
|
||||
// write output ids
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
|
||||
|
||||
const auto n_outputs = this->n_outputs;
|
||||
const auto & output_ids = this->output_ids;
|
||||
|
||||
std::vector<int32_t> w_output_pos;
|
||||
|
||||
w_output_pos.resize(n_outputs);
|
||||
|
||||
// build a more compact representation of the output ids
|
||||
for (size_t i = 0; i < n_batch(); ++i) {
|
||||
// map an output id to a position in the batch
|
||||
int64_t pos = output_ids[i];
|
||||
if (pos >= 0) {
|
||||
GGML_ASSERT(pos < n_outputs);
|
||||
w_output_pos[pos] = i;
|
||||
}
|
||||
}
|
||||
|
||||
io.write(&n_outputs, sizeof(n_outputs));
|
||||
|
||||
if (n_outputs) {
|
||||
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
// [TAG_CONTEXT_STATE_LOGITS]
|
||||
// write logits
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
|
||||
|
||||
const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
|
||||
|
||||
io.write(&logits_size, sizeof(logits_size));
|
||||
|
||||
if (logits_size) {
|
||||
io.write(logits.data, logits_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// write embeddings
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
|
||||
|
||||
const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
|
||||
|
||||
io.write(&embd_size, sizeof(embd_size));
|
||||
|
||||
if (embd_size) {
|
||||
io.write(embd.data, embd_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle sampling buffers and samplers state ?
|
||||
// https://github.com/ggml-org/llama.cpp/pull/17004
|
||||
|
||||
if (memory != nullptr) {
|
||||
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
|
||||
memory->state_write(io);
|
||||
|
|
@ -2523,70 +2465,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
|||
// TODO: add more info which needs to be identical but which is not verified otherwise
|
||||
}
|
||||
|
||||
// read output ids
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
|
||||
|
||||
auto n_outputs = this->n_outputs;
|
||||
io.read_to(&n_outputs, sizeof(n_outputs));
|
||||
|
||||
if (n_outputs > output_reserve(n_outputs)) {
|
||||
throw std::runtime_error("could not reserve outputs");
|
||||
}
|
||||
|
||||
std::vector<int32_t> output_pos;
|
||||
|
||||
if (n_outputs) {
|
||||
output_pos.resize(n_outputs);
|
||||
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
|
||||
int32_t id = output_pos[i];
|
||||
if ((uint32_t) id >= n_batch()) {
|
||||
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
|
||||
}
|
||||
this->output_ids[id] = i;
|
||||
}
|
||||
|
||||
this->n_outputs = n_outputs;
|
||||
}
|
||||
}
|
||||
|
||||
// read logits
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
|
||||
|
||||
uint64_t logits_size;
|
||||
io.read_to(&logits_size, sizeof(logits_size));
|
||||
|
||||
if (this->logits.size < logits_size) {
|
||||
throw std::runtime_error("logits buffer too small");
|
||||
}
|
||||
|
||||
if (logits_size) {
|
||||
io.read_to(this->logits.data, logits_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// read embeddings
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
|
||||
|
||||
uint64_t embd_size;
|
||||
io.read_to(&embd_size, sizeof(embd_size));
|
||||
|
||||
if (this->embd.size < embd_size) {
|
||||
throw std::runtime_error("embeddings buffer too small");
|
||||
}
|
||||
|
||||
if (embd_size) {
|
||||
io.read_to(this->embd.data, embd_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle sampling buffers and samplers state ?
|
||||
// https://github.com/ggml-org/llama.cpp/pull/17004
|
||||
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
|
||||
|
||||
|
|
|
|||
|
|
@ -1703,8 +1703,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
} break;
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
{
|
||||
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
|
||||
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
|
||||
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B
|
||||
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256));
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
|
|
|
|||
|
|
@ -2027,7 +2027,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
||||
} else if (
|
||||
tokenizer_pre == "gpt-4o" ||
|
||||
tokenizer_pre == "llama4") {
|
||||
tokenizer_pre == "llama4" ||
|
||||
tokenizer_pre == "kanana2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
|
|
|
|||
|
|
@ -361,7 +361,7 @@ static void test_backend_temp_sampling(const test_params & params) {
|
|||
GGML_ASSERT(false && "Failed to decode token");
|
||||
}
|
||||
|
||||
// Verfify sequence 0
|
||||
// Verify sequence 0
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(0);
|
||||
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
|
||||
|
|
@ -379,7 +379,7 @@ static void test_backend_temp_sampling(const test_params & params) {
|
|||
}
|
||||
|
||||
|
||||
// Verfify sequence 1
|
||||
// Verify sequence 1
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(1);
|
||||
|
||||
|
|
@ -395,7 +395,7 @@ static void test_backend_temp_sampling(const test_params & params) {
|
|||
}
|
||||
}
|
||||
|
||||
// lambda to testing non-positive temperature values.
|
||||
// lambda for testing non-positive temperature values.
|
||||
auto test_argmax_temp = [&](float temp) {
|
||||
printf("\nTesting temperature = %.1f\n", temp);
|
||||
|
||||
|
|
@ -454,7 +454,7 @@ static void test_backend_temp_ext_sampling(const test_params & params) {
|
|||
}
|
||||
}
|
||||
|
||||
// lambda to testing non-positive temp/delta/exponent values.
|
||||
// lambda for testing non-positive temp/delta/exponent values.
|
||||
auto test_argmax_temp = [&](float temp, float delta, float exponent) {
|
||||
printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent);
|
||||
|
||||
|
|
@ -530,7 +530,7 @@ static void test_backend_min_p_sampling(const test_params & params) {
|
|||
printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
// Decode and sampler 10 more tokens
|
||||
// Decode and sample 10 more tokens
|
||||
for (int i = 0; i < 10; i++) {
|
||||
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
|
||||
llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
|
||||
|
|
@ -582,7 +582,7 @@ static void test_backend_top_p_sampling(const test_params & params) {
|
|||
printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
// Decode and sampler 10 more tokens
|
||||
// Decode and sample 10 more tokens
|
||||
for (int i = 0; i < 10; i++) {
|
||||
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
|
||||
llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
|
||||
|
|
@ -619,7 +619,7 @@ static void test_backend_multi_sequence_sampling(const test_params & params) {
|
|||
GGML_ASSERT(false && "Failed to decode token");
|
||||
}
|
||||
|
||||
// Verfiy sequence 0
|
||||
// Verify sequence 0
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(0);
|
||||
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
|
||||
|
|
@ -763,7 +763,7 @@ static void test_backend_logit_bias_sampling(const test_params & params) {
|
|||
printf("backend logit bias sampling test PASSED\n");
|
||||
}
|
||||
|
||||
// This test verifies that it is possible to have two different backend sampler,
|
||||
// This test verifies that it is possible to have two different backend samplers,
|
||||
// one that uses the backend dist sampler, and another that uses CPU dist sampler.
|
||||
static void test_backend_mixed_sampling(const test_params & params) {
|
||||
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
|
||||
|
|
@ -791,7 +791,7 @@ static void test_backend_mixed_sampling(const test_params & params) {
|
|||
GGML_ASSERT(false && "Failed to decode token");
|
||||
}
|
||||
|
||||
// Verfiy sequence 0 that used the dist backend sampler.
|
||||
// Verify sequence 0 that used the dist backend sampler.
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(0);
|
||||
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
|
||||
|
|
@ -802,7 +802,7 @@ static void test_backend_mixed_sampling(const test_params & params) {
|
|||
//GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
|
||||
}
|
||||
|
||||
// Verfiy sequence 1 that used the top-k backend sampler.
|
||||
// Verify sequence 1 that used the top-k backend sampler.
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(1);
|
||||
float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
|
||||
|
|
@ -934,7 +934,7 @@ static void test_backend_cpu_mixed_batch(const test_params & params) {
|
|||
// samplers.
|
||||
llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);
|
||||
|
||||
// Create a CPU sampler and verify we can sampler from it.
|
||||
// Create a CPU sampler and verify we can sample from it.
|
||||
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
|
||||
llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
|
||||
llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
|
||||
|
|
|
|||
|
|
@ -229,6 +229,20 @@ common_chat_tool python_tool {
|
|||
"required": ["code"]
|
||||
})",
|
||||
};
|
||||
common_chat_tool todo_list_tool {
|
||||
/* .name = */ "todo_list",
|
||||
/* .description = */ "Create or update the todo list",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todos": {
|
||||
"type": "array",
|
||||
"description": "List of TODO list items"
|
||||
}
|
||||
},
|
||||
"required": ["todos"]
|
||||
})",
|
||||
};
|
||||
common_chat_tool code_interpreter_tool {
|
||||
/* .name = */ "code_interpreter",
|
||||
/* .description = */ "an ipython interpreter",
|
||||
|
|
@ -3018,542 +3032,6 @@ Hey there!<|im_end|>
|
|||
);
|
||||
}
|
||||
|
||||
// Test Qwen3-Coder XML format
|
||||
{
|
||||
// Basic XML tool call parsing
|
||||
assert_msg_equals(
|
||||
message_assist_call,
|
||||
test_chat_parse(
|
||||
"<tool_call>\n"
|
||||
" <function=special_function>\n"
|
||||
" <parameter=arg1>\n"
|
||||
" 1\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_QWEN3_CODER_XML}));
|
||||
|
||||
// Multiple parameters with different types
|
||||
common_chat_msg expected_multi_param;
|
||||
expected_multi_param.role = "assistant";
|
||||
expected_multi_param.tool_calls = {
|
||||
{ "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(expected_multi_param,
|
||||
"<tool_call>\n"
|
||||
" <function=complex_function>\n"
|
||||
" <parameter=name>\n"
|
||||
" John Doe\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=age>\n"
|
||||
" 30\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=active>\n"
|
||||
" true\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=score>\n"
|
||||
" 95.5\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Special characters and Unicode
|
||||
common_chat_msg expected_special_chars;
|
||||
expected_special_chars.role = "assistant";
|
||||
expected_special_chars.tool_calls = {
|
||||
{ "unicode_function", "{\"message\":\"Hello 世界! 🌍 Special chars: @#$%^&*()\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(expected_special_chars,
|
||||
"<tool_call>\n"
|
||||
" <function=unicode_function>\n"
|
||||
" <parameter=message>\n"
|
||||
" Hello 世界! 🌍 Special chars: @#$%^&*()\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Multiline content with newlines and indentation
|
||||
common_chat_msg expected_multiline;
|
||||
expected_multiline.role = "assistant";
|
||||
expected_multiline.tool_calls = {
|
||||
{ "code_function", "{\"code\":\"def hello():\\n print(\\\"Hello, World!\\\")\\n return True\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(expected_multiline,
|
||||
"<tool_call>\n"
|
||||
" <function=code_function>\n"
|
||||
" <parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, World!\")\n"
|
||||
" return True\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// JSON object as parameter value
|
||||
common_chat_msg expected_json_param;
|
||||
expected_json_param.role = "assistant";
|
||||
expected_json_param.tool_calls = {
|
||||
{ "json_function", "{\"config\":{\"host\":\"localhost\",\"port\":8080,\"ssl\":false}}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_json_param,
|
||||
"<tool_call>\n"
|
||||
" <function=json_function>\n"
|
||||
" <parameter=config>\n"
|
||||
" {\"host\": \"localhost\", \"port\": 8080, \"ssl\": false}\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Array as parameter value
|
||||
common_chat_msg expected_array_param;
|
||||
expected_array_param.role = "assistant";
|
||||
expected_array_param.tool_calls = {
|
||||
{ "array_function", "{\"items\":[\"apple\",\"banana\",\"cherry\"]}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_array_param,
|
||||
"<tool_call>\n"
|
||||
" <function=array_function>\n"
|
||||
" <parameter=items>\n"
|
||||
" [\"apple\", \"banana\", \"cherry\"]\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Empty parameter
|
||||
common_chat_msg expected_empty_param;
|
||||
expected_empty_param.role = "assistant";
|
||||
expected_empty_param.tool_calls = {
|
||||
{ "empty_function", "{\"empty_param\":\"\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_empty_param,
|
||||
"<tool_call>\n"
|
||||
" <function=empty_function>\n"
|
||||
" <parameter=empty_param>\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Boolean values (true/false)
|
||||
common_chat_msg expected_boolean;
|
||||
expected_boolean.role = "assistant";
|
||||
expected_boolean.tool_calls = {
|
||||
{ "boolean_function", "{\"enabled\":true,\"debug\":false}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_boolean,
|
||||
"<tool_call>\n"
|
||||
" <function=boolean_function>\n"
|
||||
" <parameter=enabled>\n"
|
||||
" true\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=debug>\n"
|
||||
" false\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Null value
|
||||
common_chat_msg expected_null;
|
||||
expected_null.role = "assistant";
|
||||
expected_null.tool_calls = {
|
||||
{ "null_function", "{\"optional_param\":null}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_null,
|
||||
"<tool_call>\n"
|
||||
" <function=null_function>\n"
|
||||
" <parameter=optional_param>\n"
|
||||
" null\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Negative numbers and scientific notation
|
||||
common_chat_msg expected_numbers;
|
||||
expected_numbers.role = "assistant";
|
||||
expected_numbers.tool_calls = {
|
||||
{ "math_function", "{\"negative\":-42,\"decimal\":-3.14,\"scientific\":1.23e-4}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_numbers,
|
||||
"<tool_call>\n"
|
||||
" <function=math_function>\n"
|
||||
" <parameter=negative>\n"
|
||||
" -42\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=decimal>\n"
|
||||
" -3.14\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=scientific>\n"
|
||||
" 1.23e-4\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// XML-like content in parameters (should be escaped)
|
||||
common_chat_msg expected_xml_content;
|
||||
expected_xml_content.role = "assistant";
|
||||
expected_xml_content.tool_calls = {
|
||||
{ "xml_function", "{\"xml_content\":\"<root><item>value</item></root>\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_xml_content,
|
||||
"<tool_call>\n"
|
||||
" <function=xml_function>\n"
|
||||
" <parameter=xml_content>\n"
|
||||
" <root><item>value</item></root>\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Quotes and escape characters
|
||||
common_chat_msg expected_quotes;
|
||||
expected_quotes.role = "assistant";
|
||||
expected_quotes.tool_calls = {
|
||||
{ "quote_function", "{\"message\":\"She said \\\"Hello!\\\" and left.\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_quotes,
|
||||
"<tool_call>\n"
|
||||
" <function=quote_function>\n"
|
||||
" <parameter=message>\n"
|
||||
" She said \"Hello!\" and left.\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Long parameter value (simplified)
|
||||
std::string long_text = "This is a long text parameter that should test the parser's ability to handle larger amounts of text data.";
|
||||
|
||||
common_chat_msg expected_long_text;
|
||||
expected_long_text.role = "assistant";
|
||||
expected_long_text.tool_calls = {
|
||||
{ "long_function", "{\"long_text\":\"" + long_text + "\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_long_text,
|
||||
"<tool_call>\n"
|
||||
" <function=long_function>\n"
|
||||
" <parameter=long_text>\n"
|
||||
" " + long_text + "\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Mixed content with text before and after tool call
|
||||
common_chat_msg expected_mixed_content;
|
||||
expected_mixed_content.role = "assistant";
|
||||
expected_mixed_content.content = "I'll help you search for products. ";
|
||||
expected_mixed_content.tool_calls = {
|
||||
{ "search_function", "{\"query\":\"laptops\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_mixed_content,
|
||||
"I'll help you search for products. <tool_call>\n"
|
||||
" <function=search_function>\n"
|
||||
" <parameter=query>\n"
|
||||
" laptops\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Compact format (no extra whitespace)
|
||||
common_chat_msg expected_compact;
|
||||
expected_compact.role = "assistant";
|
||||
expected_compact.tool_calls = {
|
||||
{ "compact_function", "{\"param\":\"value\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_compact,
|
||||
"<tool_call><function=compact_function><parameter=param>value</parameter></function></tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Function name with underscores and numbers
|
||||
common_chat_msg expected_complex_name;
|
||||
expected_complex_name.role = "assistant";
|
||||
expected_complex_name.tool_calls = {
|
||||
{ "get_user_data_v2", "{\"user_id\":12345}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_complex_name,
|
||||
"<tool_call>\n"
|
||||
" <function=get_user_data_v2>\n"
|
||||
" <parameter=user_id>\n"
|
||||
" 12345\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Parameter names with underscores and numbers
|
||||
common_chat_msg expected_complex_params;
|
||||
expected_complex_params.role = "assistant";
|
||||
expected_complex_params.tool_calls = {
|
||||
{ "test_function", "{\"param_1\":\"value1\",\"param_2_name\":\"value2\",\"param3\":123}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_complex_params,
|
||||
"<tool_call>\n"
|
||||
" <function=test_function>\n"
|
||||
" <parameter=param_1>\n"
|
||||
" value1\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=param_2_name>\n"
|
||||
" value2\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=param3>\n"
|
||||
" 123\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Very deeply nested XML content in parameter
|
||||
common_chat_msg expected_deep_xml;
|
||||
expected_deep_xml.role = "assistant";
|
||||
expected_deep_xml.tool_calls = {
|
||||
{ "xml_parser", "{\"xml\":\"<root><level1><level2><level3>deep content</level3></level2></level1></root>\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_deep_xml,
|
||||
"<tool_call>\n"
|
||||
" <function=xml_parser>\n"
|
||||
" <parameter=xml>\n"
|
||||
" <root><level1><level2><level3>deep content</level3></level2></level1></root>\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Parameter with only whitespace
|
||||
common_chat_msg expected_whitespace_param;
|
||||
expected_whitespace_param.role = "assistant";
|
||||
expected_whitespace_param.tool_calls = {
|
||||
{ "whitespace_function", "{\"spaces\":\"\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_whitespace_param,
|
||||
"<tool_call>\n"
|
||||
" <function=whitespace_function>\n"
|
||||
" <parameter=spaces>\n"
|
||||
" \n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Parameter with tabs and mixed whitespace
|
||||
common_chat_msg expected_mixed_whitespace;
|
||||
expected_mixed_whitespace.role = "assistant";
|
||||
expected_mixed_whitespace.tool_calls = {
|
||||
{ "tab_function", "{\"content\":\"line1\\n\\tindented line\\n spaces\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_mixed_whitespace,
|
||||
"<tool_call>\n"
|
||||
" <function=tab_function>\n"
|
||||
" <parameter=content>\n"
|
||||
"line1\n"
|
||||
"\tindented line\n"
|
||||
" spaces\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Control characters and special Unicode
|
||||
common_chat_msg expected_control_chars;
|
||||
expected_control_chars.role = "assistant";
|
||||
expected_control_chars.tool_calls = {
|
||||
{ "control_function", "{\"text\":\"Line1\\nLine2\\tTabbed\\rCarriage return\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_control_chars,
|
||||
"<tool_call>\n"
|
||||
" <function=control_function>\n"
|
||||
" <parameter=text>\n"
|
||||
"Line1\nLine2\tTabbed\rCarriage return\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Emoji and extended Unicode characters
|
||||
common_chat_msg expected_emoji;
|
||||
expected_emoji.role = "assistant";
|
||||
expected_emoji.tool_calls = {
|
||||
{ "emoji_function", "{\"message\":\"Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_emoji,
|
||||
"<tool_call>\n"
|
||||
" <function=emoji_function>\n"
|
||||
" <parameter=message>\n"
|
||||
" Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Mathematical expressions and formulas
|
||||
common_chat_msg expected_math;
|
||||
expected_math.role = "assistant";
|
||||
expected_math.tool_calls = {
|
||||
{ "math_function", "{\"formula\":\"E = mc² and ∫f(x)dx = F(x) + C\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_math,
|
||||
"<tool_call>\n"
|
||||
" <function=math_function>\n"
|
||||
" <parameter=formula>\n"
|
||||
" E = mc² and ∫f(x)dx = F(x) + C\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// SQL injection-like content (should be safely escaped)
|
||||
common_chat_msg expected_sql;
|
||||
expected_sql.role = "assistant";
|
||||
expected_sql.tool_calls = {
|
||||
{ "sql_function", "{\"query\":\"SELECT * FROM users WHERE id = 1; DROP TABLE users; --\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_sql,
|
||||
"<tool_call>\n"
|
||||
" <function=sql_function>\n"
|
||||
" <parameter=query>\n"
|
||||
" SELECT * FROM users WHERE id = 1; DROP TABLE users; --\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// HTML/XML injection content
|
||||
common_chat_msg expected_html;
|
||||
expected_html.role = "assistant";
|
||||
expected_html.tool_calls = {
|
||||
{ "html_function", "{\"content\":\"<script>alert('xss')</script><img src=x onerror=alert(1)>\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_html,
|
||||
"<tool_call>\n"
|
||||
" <function=html_function>\n"
|
||||
" <parameter=content>\n"
|
||||
" <script>alert('xss')</script><img src=x onerror=alert(1)>\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Binary-like content (base64)
|
||||
common_chat_msg expected_binary;
|
||||
expected_binary.role = "assistant";
|
||||
expected_binary.tool_calls = {
|
||||
{ "binary_function", "{\"data\":\"SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_binary,
|
||||
"<tool_call>\n"
|
||||
" <function=binary_function>\n"
|
||||
" <parameter=data>\n"
|
||||
" SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Very large numbers (should be parsed as scientific notation)
|
||||
common_chat_msg expected_large_numbers;
|
||||
expected_large_numbers.role = "assistant";
|
||||
expected_large_numbers.tool_calls = {
|
||||
{ "number_function", "{\"big_int\":1e+60}", "" } // Large number becomes scientific notation
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_large_numbers,
|
||||
"<tool_call>\n"
|
||||
" <function=number_function>\n"
|
||||
" <parameter=big_int>\n"
|
||||
" 999999999999999999999999999999999999999999999999999999999999\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
}
|
||||
|
||||
{
|
||||
// Qwen3-Coder template
|
||||
auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja");
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = { message_user };
|
||||
|
||||
common_chat_tool qwen_union_tool {
|
||||
/* .name = */ "qwen_union",
|
||||
/* .description = */ "Test tool for union/anyOf handling",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": { "type": ["number", "null"] },
|
||||
"maybe_text": { "anyOf": [ { "type": "string" } ] },
|
||||
"config": { "anyOf": [ { "type": "object" }, { "type": "null" } ] }
|
||||
},
|
||||
"required": []
|
||||
})",
|
||||
};
|
||||
inputs.tools = { qwen_union_tool };
|
||||
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs);
|
||||
assert_equals(COMMON_CHAT_FORMAT_QWEN3_CODER_XML, params.format);
|
||||
assert_equals(false, params.grammar.empty());
|
||||
|
||||
// Grammar should compile successfully
|
||||
auto grammar = build_grammar(params.grammar);
|
||||
GGML_ASSERT(grammar && "Failed to build Qwen3-Coder grammar with union types");
|
||||
}
|
||||
|
||||
{
|
||||
// Step-3.5-Flash template: uses same XML output format as Qwen3-Coder and Nemotron v3,
|
||||
// but with <think> support. Routes to the Nemotron v3 PEG parser for streaming and
|
||||
|
|
@ -3665,6 +3143,135 @@ static void test_template_output_peg_parsers() {
|
|||
});
|
||||
}
|
||||
|
||||
{
|
||||
// Qwen3-Coder
|
||||
auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja");
|
||||
|
||||
// Test basic message
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "Hello, world!\nWhat's up?";
|
||||
t.expect = message_assist;
|
||||
});
|
||||
|
||||
// Test tool call
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=special_function>\n"
|
||||
"<parameter=arg1>\n"
|
||||
"1\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.tools = {special_function_tool};
|
||||
t.expect = message_assist_call;
|
||||
});
|
||||
|
||||
// Test parallel tool calls
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=special_function>\n"
|
||||
"<parameter=arg1>\n"
|
||||
"1\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=special_function_with_opt>\n"
|
||||
"<parameter=arg1>\n"
|
||||
"1\n"
|
||||
"</parameter>\n"
|
||||
"<parameter=arg2>\n"
|
||||
"2\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.parallel_tool_calls = true;
|
||||
t.params.tools = {special_function_tool, special_function_tool_with_optional_param};
|
||||
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "special_function",
|
||||
/* .arguments = */ R"({"arg1": 1})",
|
||||
/* .id = */ {},
|
||||
}, {
|
||||
/* .name = */ "special_function_with_opt",
|
||||
/* .arguments = */ R"({"arg1": 1, "arg2": 2})",
|
||||
/* .id = */ {},
|
||||
}};
|
||||
});
|
||||
|
||||
// Test tool call with string parameter
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.tools = {python_tool};
|
||||
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}",
|
||||
/* .id = */ {},
|
||||
}};
|
||||
});
|
||||
|
||||
// Test tool call with JSON parameter
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=todo_list>\n"
|
||||
"<parameter=todos>\n"
|
||||
"[{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.tools = {todo_list_tool};
|
||||
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "todo_list",
|
||||
/* .arguments = */ "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}",
|
||||
/* .id = */ {},
|
||||
}};
|
||||
});
|
||||
|
||||
// Test tool call with string parameter and no closing </parameter> tag
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.tools = {python_tool};
|
||||
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}",
|
||||
/* .id = */ {},
|
||||
}};
|
||||
});
|
||||
|
||||
// Test response format
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = R"({"amount": 123.45, "date": "2025-12-03"})";
|
||||
t.params.json_schema = invoice_schema;
|
||||
|
||||
t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})";
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
// NVIDIA Nemotron-3 Nano
|
||||
auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja");
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ static void test_string_methods(testing & t);
|
|||
static void test_array_methods(testing & t);
|
||||
static void test_object_methods(testing & t);
|
||||
static void test_hasher(testing & t);
|
||||
static void test_stats(testing & t);
|
||||
static void test_fuzzing(testing & t);
|
||||
|
||||
static bool g_python_mode = false;
|
||||
|
|
@ -70,6 +71,7 @@ int main(int argc, char *argv[]) {
|
|||
t.test("object methods", test_object_methods);
|
||||
if (!g_python_mode) {
|
||||
t.test("hasher", test_hasher);
|
||||
t.test("stats", test_stats);
|
||||
t.test("fuzzing", test_fuzzing);
|
||||
}
|
||||
|
||||
|
|
@ -1795,6 +1797,63 @@ static void test_hasher(testing & t) {
|
|||
});
|
||||
}
|
||||
|
||||
static void test_stats(testing & t) {
|
||||
static auto get_stats = [](const std::string & tmpl, const json & vars) -> jinja::value {
|
||||
jinja::lexer lexer;
|
||||
auto lexer_res = lexer.tokenize(tmpl);
|
||||
|
||||
jinja::program prog = jinja::parse_from_tokens(lexer_res);
|
||||
|
||||
jinja::context ctx(tmpl);
|
||||
jinja::global_from_json(ctx, json{{ "val", vars }}, true);
|
||||
ctx.is_get_stats = true;
|
||||
|
||||
jinja::runtime runtime(ctx);
|
||||
runtime.execute(prog);
|
||||
|
||||
return ctx.get_val("val");
|
||||
};
|
||||
|
||||
t.test("stats", [](testing & t) {
|
||||
jinja::value val = get_stats(
|
||||
"{{val.num}} "
|
||||
"{{val.str}} "
|
||||
"{{val.arr[0]}} "
|
||||
"{{val.obj.key1}} "
|
||||
"{{val.nested | tojson}}",
|
||||
// Note: the json below will be wrapped inside "val" in the context
|
||||
json{
|
||||
{"num", 1},
|
||||
{"str", "abc"},
|
||||
{"arr", json::array({1, 2, 3})},
|
||||
{"obj", json::object({{"key1", 1}, {"key2", 2}, {"key3", 3}})},
|
||||
{"nested", json::object({
|
||||
{"inner_key1", json::array({1, 2})},
|
||||
{"inner_key2", json::object({{"a", "x"}, {"b", "y"}})}
|
||||
})},
|
||||
{"mixed", json::object({
|
||||
{"used", 1},
|
||||
{"unused", 2},
|
||||
})},
|
||||
}
|
||||
);
|
||||
|
||||
t.assert_true("num is used", val->at("num")->stats.used);
|
||||
t.assert_true("str is used", val->at("str")->stats.used);
|
||||
|
||||
t.assert_true("arr is used", val->at("arr")->stats.used);
|
||||
t.assert_true("arr[0] is used", val->at("arr")->at(0)->stats.used);
|
||||
t.assert_true("arr[1] is not used", !val->at("arr")->at(1)->stats.used);
|
||||
|
||||
t.assert_true("obj is used", val->at("obj")->stats.used);
|
||||
t.assert_true("obj.key1 is used", val->at("obj")->at("key1")->stats.used);
|
||||
t.assert_true("obj.key2 is not used", !val->at("obj")->at("key2")->stats.used);
|
||||
|
||||
t.assert_true("inner_key1[0] is used", val->at("nested")->at("inner_key1")->at(0)->stats.used);
|
||||
t.assert_true("inner_key2.a is used", val->at("nested")->at("inner_key2")->at("a")->stats.used);
|
||||
});
|
||||
}
|
||||
|
||||
static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
|
||||
t.test(name, [&tmpl, &vars, &expect](testing & t) {
|
||||
jinja::lexer lexer;
|
||||
|
|
|
|||
|
|
@ -380,6 +380,15 @@ int main(int argc, char ** argv) {
|
|||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
|
||||
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
|
||||
cur_msg += fname;
|
||||
cur_msg.push_back('\n');
|
||||
} else {
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded text from '%s'\n", fname.c_str());
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -387,6 +387,17 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
|
||||
|
||||
// Logits are not stored as part of the session state so we need to
|
||||
// "replay" the last token to get logits for sampling.
|
||||
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
|
||||
if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
session_do_save = false;
|
||||
LOG_INF("%s: replayed last token from session\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// number of tokens to keep when resetting context
|
||||
|
|
@ -675,40 +686,27 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
if (!embd.empty()) {
|
||||
int n_eval = (int) embd.size();
|
||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
GGML_ASSERT(n_eval <= params.n_batch);
|
||||
if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
|
||||
const bool save_now = session_do_save && is_last_batch;
|
||||
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
n_past += n_eval;
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
|
||||
n_session_consumed = session_tokens.size();
|
||||
session_do_save = false;
|
||||
|
||||
LOG_DBG("n_past = %d\n", n_past);
|
||||
|
||||
// Display total tokens alongside total time
|
||||
if (params.n_print > 0 && n_past % params.n_print == 0) {
|
||||
LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
if (!embd.empty() && !path_session.empty()) {
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
||||
n_session_consumed = session_tokens.size();
|
||||
}
|
||||
}
|
||||
|
||||
embd.clear();
|
||||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
// optionally save the session on first sample (for faster prompt loading next time)
|
||||
if (session_do_save) {
|
||||
session_do_save = false;
|
||||
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
|
||||
LOG_DBG("saved session to %s\n", path_session.c_str());
|
||||
}
|
||||
|
||||
const llama_token id = common_sampler_sample(smpl, ctx, -1);
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -1105,6 +1105,8 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
};
|
||||
|
||||
for (json item : input_value) {
|
||||
bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant";
|
||||
|
||||
if (exists_and_is_string(item, "content")) {
|
||||
// #responses_create-input-input_item_list-input_message-content-text_input
|
||||
// Only "Input message" contains item["content"]::string
|
||||
|
|
@ -1193,7 +1195,7 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
item.at("type") == "message"
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-output_message
|
||||
std::vector<json> chatcmpl_content;
|
||||
auto chatcmpl_content = json::array();
|
||||
|
||||
for (const auto & output_text : item.at("content")) {
|
||||
const std::string type = json_value(output_text, "type", std::string());
|
||||
|
|
@ -1210,10 +1212,19 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
});
|
||||
}
|
||||
|
||||
item.erase("status");
|
||||
item.erase("type");
|
||||
item["content"] = chatcmpl_content;
|
||||
chatcmpl_messages.push_back(item);
|
||||
if (merge_prev) {
|
||||
auto & prev_msg = chatcmpl_messages.back();
|
||||
if (!exists_and_is_array(prev_msg, "content")) {
|
||||
prev_msg["content"] = json::array();
|
||||
}
|
||||
auto & prev_content = prev_msg["content"];
|
||||
prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end());
|
||||
} else {
|
||||
item.erase("status");
|
||||
item.erase("type");
|
||||
item["content"] = chatcmpl_content;
|
||||
chatcmpl_messages.push_back(item);
|
||||
}
|
||||
} else if (exists_and_is_string(item, "arguments") &&
|
||||
exists_and_is_string(item, "call_id") &&
|
||||
exists_and_is_string(item, "name") &&
|
||||
|
|
@ -1221,24 +1232,27 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
item.at("type") == "function_call"
|
||||
) {
|
||||
// #responses_create-input-input_item_list-item-function_tool_call
|
||||
json msg = json {
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({ json {
|
||||
{"function", json {
|
||||
{"arguments", item.at("arguments")},
|
||||
{"name", item.at("name")},
|
||||
}},
|
||||
{"id", item.at("call_id")},
|
||||
{"type", "function"},
|
||||
}})},
|
||||
json tool_call = {
|
||||
{"function", json {
|
||||
{"arguments", item.at("arguments")},
|
||||
{"name", item.at("name")},
|
||||
}},
|
||||
{"id", item.at("call_id")},
|
||||
{"type", "function"},
|
||||
};
|
||||
|
||||
if (!chatcmpl_messages.empty() && chatcmpl_messages.back().contains("reasoning_content")) {
|
||||
// Move reasoning content from dummy message to tool call message
|
||||
msg["reasoning_content"] = chatcmpl_messages.back().at("reasoning_content");
|
||||
chatcmpl_messages.pop_back();
|
||||
if (merge_prev) {
|
||||
auto & prev_msg = chatcmpl_messages.back();
|
||||
if (!exists_and_is_array(prev_msg, "tool_calls")) {
|
||||
prev_msg["tool_calls"] = json::array();
|
||||
}
|
||||
prev_msg["tool_calls"].push_back(tool_call);
|
||||
} else {
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"role", "assistant"},
|
||||
{"tool_calls", json::array({tool_call})}
|
||||
});
|
||||
}
|
||||
chatcmpl_messages.push_back(msg);
|
||||
} else if (exists_and_is_string(item, "call_id") &&
|
||||
(exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) &&
|
||||
exists_and_is_string(item, "type") &&
|
||||
|
|
@ -1282,12 +1296,16 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
throw std::invalid_argument("item['content']['text'] is not a string");
|
||||
}
|
||||
|
||||
// Pack reasoning content in dummy message
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"role", "assistant"},
|
||||
{"content", json::array()},
|
||||
{"reasoning_content", item.at("content")[0].at("text")},
|
||||
});
|
||||
if (merge_prev) {
|
||||
auto & prev_msg = chatcmpl_messages.back();
|
||||
prev_msg["reasoning_content"] = item.at("content")[0].at("text");
|
||||
} else {
|
||||
chatcmpl_messages.push_back(json {
|
||||
{"role", "assistant"},
|
||||
{"content", json::array()},
|
||||
{"reasoning_content", item.at("content")[0].at("text")},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("Cannot determine type of 'item'");
|
||||
}
|
||||
|
|
@ -1296,20 +1314,6 @@ json convert_responses_to_chatcmpl(const json & response_body) {
|
|||
throw std::invalid_argument("'input' must be a string or array of objects");
|
||||
}
|
||||
|
||||
// Remove unused dummy message which contains
|
||||
// reasoning content not followed by tool call
|
||||
chatcmpl_messages.erase(std::remove_if(
|
||||
chatcmpl_messages.begin(),
|
||||
chatcmpl_messages.end(),
|
||||
[](const json & x){ return x.contains("role") &&
|
||||
x.at("role") == "assistant" &&
|
||||
x.contains("content") &&
|
||||
x.at("content") == json::array() &&
|
||||
x.contains("reasoning_content");
|
||||
}),
|
||||
chatcmpl_messages.end()
|
||||
);
|
||||
|
||||
chatcmpl_body["messages"] = chatcmpl_messages;
|
||||
|
||||
if (response_body.contains("tools")) {
|
||||
|
|
|
|||
|
|
@ -2911,6 +2911,9 @@ server_context_meta server_context::get_meta() const {
|
|||
/* fim_pre_token */ llama_vocab_fim_pre(impl->vocab),
|
||||
/* fim_sub_token */ llama_vocab_fim_suf(impl->vocab),
|
||||
/* fim_mid_token */ llama_vocab_fim_mid(impl->vocab),
|
||||
/* fim_pad_token */ llama_vocab_fim_pad(impl->vocab),
|
||||
/* fim_rep_token */ llama_vocab_fim_rep(impl->vocab),
|
||||
/* fim_sep_token */ llama_vocab_fim_sep(impl->vocab),
|
||||
|
||||
/* model_vocab_type */ llama_vocab_type(impl->vocab),
|
||||
/* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab),
|
||||
|
|
|
|||
|
|
@ -30,6 +30,9 @@ struct server_context_meta {
|
|||
llama_token fim_pre_token;
|
||||
llama_token fim_sub_token;
|
||||
llama_token fim_mid_token;
|
||||
llama_token fim_pad_token;
|
||||
llama_token fim_rep_token;
|
||||
llama_token fim_sep_token;
|
||||
|
||||
// model meta
|
||||
enum llama_vocab_type model_vocab_type;
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ In a separate terminal, start the backend server:
|
|||
./llama-server -m model.gguf
|
||||
|
||||
# Multi-model (ROUTER mode)
|
||||
./llama-server --model-store /path/to/models
|
||||
./llama-server --models-dir /path/to/models
|
||||
```
|
||||
|
||||
### 3. Start Development Servers
|
||||
|
|
|
|||
|
|
@ -114,6 +114,11 @@
|
|||
label: 'Render user content as Markdown',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.FULL_HEIGHT_CODE_BLOCKS,
|
||||
label: 'Use full height code blocks',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.DISABLE_AUTO_SCROLL,
|
||||
label: 'Disable automatic scroll',
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@
|
|||
import { ActionIconsCodeBlock, DialogCodePreview } from '$lib/components/app';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import type { DatabaseMessageExtra } from '$lib/types/database';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { SETTINGS_KEYS } from '$lib/constants/settings-keys';
|
||||
|
||||
interface Props {
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
|
|
@ -593,7 +595,12 @@
|
|||
});
|
||||
</script>
|
||||
|
||||
<div bind:this={containerRef} class={className}>
|
||||
<div
|
||||
bind:this={containerRef}
|
||||
class="{className}{config()[SETTINGS_KEYS.FULL_HEIGHT_CODE_BLOCKS]
|
||||
? ' full-height-code-blocks'
|
||||
: ''}"
|
||||
>
|
||||
{#each renderedBlocks as block (block.id)}
|
||||
<div class="markdown-block" data-block-id={block.id}>
|
||||
<!-- eslint-disable-next-line no-at-html-tags -->
|
||||
|
|
@ -914,6 +921,16 @@
|
|||
line-height: 1.3;
|
||||
}
|
||||
|
||||
.full-height-code-blocks :global(.code-block-wrapper) {
|
||||
max-height: none;
|
||||
}
|
||||
|
||||
.full-height-code-blocks :global(.code-block-scroll-container),
|
||||
.full-height-code-blocks .streaming-code-scroll-container {
|
||||
max-height: none;
|
||||
overflow-y: visible;
|
||||
}
|
||||
|
||||
div :global(.code-block-header) {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
|
|
|
|||
|
|
@ -251,9 +251,6 @@
|
|||
return options.find((option) => option.id === activeId);
|
||||
}
|
||||
|
||||
if (options.length === 1) {
|
||||
return options[0];
|
||||
}
|
||||
// No selection - return undefined to show "Select model"
|
||||
return undefined;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
|||
alwaysShowSidebarOnDesktop: false,
|
||||
autoShowSidebarOnNewChat: true,
|
||||
autoMicOnEmpty: false,
|
||||
fullHeightCodeBlocks: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'top_k;typ_p;top_p;min_p;temperature',
|
||||
backend_sampling: false,
|
||||
|
|
@ -113,6 +114,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
|||
'Automatically show sidebar when starting a new chat. Disable to keep the sidebar hidden until you click on it.',
|
||||
autoMicOnEmpty:
|
||||
'Automatically show microphone button instead of send button when textarea is empty for models with audio modality support.',
|
||||
fullHeightCodeBlocks:
|
||||
'Always display code blocks at their full natural height, overriding any height limits.',
|
||||
pyInterpreterEnabled:
|
||||
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.',
|
||||
enableContinueGeneration:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ export const SETTINGS_KEYS = {
|
|||
DISABLE_AUTO_SCROLL: 'disableAutoScroll',
|
||||
ALWAYS_SHOW_SIDEBAR_ON_DESKTOP: 'alwaysShowSidebarOnDesktop',
|
||||
AUTO_SHOW_SIDEBAR_ON_NEW_CHAT: 'autoShowSidebarOnNewChat',
|
||||
FULL_HEIGHT_CODE_BLOCKS: 'fullHeightCodeBlocks',
|
||||
// Sampling
|
||||
TEMPERATURE: 'temperature',
|
||||
DYNATEMP_RANGE: 'dynatemp_range',
|
||||
|
|
|
|||
|
|
@ -153,6 +153,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
|||
serverKey: 'enableContinueGeneration',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'fullHeightCodeBlocks',
|
||||
serverKey: 'fullHeightCodeBlocks',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
}
|
||||
];
|
||||
|
||||
|
|
|
|||
|
|
@ -306,6 +306,16 @@ class ModelsStore {
|
|||
const response = await ModelsService.listRouter();
|
||||
this.routerModels = response.data;
|
||||
await this.fetchModalitiesForLoadedModels();
|
||||
|
||||
const o = this.models.filter((option) => {
|
||||
const modelProps = this.getModelProps(option.model);
|
||||
|
||||
return modelProps?.webui !== false;
|
||||
});
|
||||
|
||||
if (o.length === 1 && this.isModelLoaded(o[0].model)) {
|
||||
this.selectModelById(o[0].id);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to fetch router models:', error);
|
||||
this.routerModels = [];
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -8,8 +8,8 @@
|
|||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.32.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002000"
|
||||
#define CPPHTTPLIB_VERSION "0.34.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002200"
|
||||
|
||||
/*
|
||||
* Platform compatibility check
|
||||
|
|
@ -185,6 +185,14 @@
|
|||
: 0))
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_THREAD_POOL_MAX_COUNT
|
||||
#define CPPHTTPLIB_THREAD_POOL_MAX_COUNT (CPPHTTPLIB_THREAD_POOL_COUNT * 4)
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT
|
||||
#define CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT 3 // seconds
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_RECV_FLAGS
|
||||
#define CPPHTTPLIB_RECV_FLAGS 0
|
||||
#endif
|
||||
|
|
@ -201,6 +209,22 @@
|
|||
#define CPPHTTPLIB_MAX_LINE_LENGTH 32768
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH
|
||||
#define CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH 16777216
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND
|
||||
#define CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND 300
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND
|
||||
#define CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND 5
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND
|
||||
#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Headers
|
||||
*/
|
||||
|
|
@ -310,6 +334,7 @@ using socket_t = int;
|
|||
#include <errno.h>
|
||||
#include <exception>
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
|
|
@ -328,6 +353,9 @@ using socket_t = int;
|
|||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#if __cplusplus >= 201703L
|
||||
#include <any>
|
||||
#endif
|
||||
|
||||
#if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \
|
||||
defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
|
|
@ -415,10 +443,46 @@ using socket_t = int;
|
|||
|
||||
#endif // CPPHTTPLIB_MBEDTLS_SUPPORT
|
||||
|
||||
#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT
|
||||
#include <wolfssl/options.h>
|
||||
|
||||
#include <wolfssl/openssl/x509v3.h>
|
||||
|
||||
// Fallback definitions for older wolfSSL versions (e.g., 5.6.6)
|
||||
#ifndef WOLFSSL_GEN_EMAIL
|
||||
#define WOLFSSL_GEN_EMAIL 1
|
||||
#endif
|
||||
#ifndef WOLFSSL_GEN_DNS
|
||||
#define WOLFSSL_GEN_DNS 2
|
||||
#endif
|
||||
#ifndef WOLFSSL_GEN_URI
|
||||
#define WOLFSSL_GEN_URI 6
|
||||
#endif
|
||||
#ifndef WOLFSSL_GEN_IPADD
|
||||
#define WOLFSSL_GEN_IPADD 7
|
||||
#endif
|
||||
|
||||
#include <wolfssl/ssl.h>
|
||||
#include <wolfssl/wolfcrypt/hash.h>
|
||||
#include <wolfssl/wolfcrypt/md5.h>
|
||||
#include <wolfssl/wolfcrypt/sha256.h>
|
||||
#include <wolfssl/wolfcrypt/sha512.h>
|
||||
#ifdef _WIN32
|
||||
#include <wincrypt.h>
|
||||
#ifdef _MSC_VER
|
||||
#pragma comment(lib, "crypt32.lib")
|
||||
#endif
|
||||
#endif // _WIN32
|
||||
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#if TARGET_OS_MAC
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif // CPPHTTPLIB_WOLFSSL_SUPPORT
|
||||
|
||||
// Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available
|
||||
// This simplifies conditional compilation when adding new backends (e.g.,
|
||||
// wolfSSL)
|
||||
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || defined(CPPHTTPLIB_MBEDTLS_SUPPORT)
|
||||
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \
|
||||
defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT)
|
||||
#define CPPHTTPLIB_SSL_ENABLED
|
||||
#endif
|
||||
|
||||
|
|
@ -440,6 +504,10 @@ using socket_t = int;
|
|||
*/
|
||||
namespace httplib {
|
||||
|
||||
namespace ws {
|
||||
class WebSocket;
|
||||
} // namespace ws
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
|
|
@ -711,6 +779,143 @@ using Match = std::smatch;
|
|||
using DownloadProgress = std::function<bool(size_t current, size_t total)>;
|
||||
using UploadProgress = std::function<bool(size_t current, size_t total)>;
|
||||
|
||||
|
||||
#if __cplusplus >= 201703L
|
||||
|
||||
using any = std::any;
|
||||
using bad_any_cast = std::bad_any_cast;
|
||||
|
||||
template <typename T> T any_cast(const any &a) { return std::any_cast<T>(a); }
|
||||
template <typename T> T any_cast(any &a) { return std::any_cast<T>(a); }
|
||||
template <typename T> T any_cast(any &&a) {
|
||||
return std::any_cast<T>(std::move(a));
|
||||
}
|
||||
template <typename T> const T *any_cast(const any *a) noexcept {
|
||||
return std::any_cast<T>(a);
|
||||
}
|
||||
template <typename T> T *any_cast(any *a) noexcept {
|
||||
return std::any_cast<T>(a);
|
||||
}
|
||||
|
||||
#else // C++11/14 implementation
|
||||
|
||||
class bad_any_cast : public std::bad_cast {
|
||||
public:
|
||||
const char *what() const noexcept override { return "bad any_cast"; }
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
using any_type_id = const void *;
|
||||
|
||||
// Returns a unique per-type ID without RTTI.
|
||||
// The static address is stable across TUs because function templates are
|
||||
// implicitly inline and the ODR merges their statics into one.
|
||||
template <typename T> any_type_id any_typeid() noexcept {
|
||||
static const char id = 0;
|
||||
return &id;
|
||||
}
|
||||
|
||||
struct any_storage {
|
||||
virtual ~any_storage() = default;
|
||||
virtual std::unique_ptr<any_storage> clone() const = 0;
|
||||
virtual any_type_id type_id() const noexcept = 0;
|
||||
};
|
||||
|
||||
template <typename T> struct any_value final : any_storage {
|
||||
T value;
|
||||
template <typename U> explicit any_value(U &&v) : value(std::forward<U>(v)) {}
|
||||
std::unique_ptr<any_storage> clone() const override {
|
||||
return std::unique_ptr<any_storage>(new any_value<T>(value));
|
||||
}
|
||||
any_type_id type_id() const noexcept override { return any_typeid<T>(); }
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
class any {
|
||||
std::unique_ptr<detail::any_storage> storage_;
|
||||
|
||||
public:
|
||||
any() noexcept = default;
|
||||
any(const any &o) : storage_(o.storage_ ? o.storage_->clone() : nullptr) {}
|
||||
any(any &&) noexcept = default;
|
||||
any &operator=(const any &o) {
|
||||
storage_ = o.storage_ ? o.storage_->clone() : nullptr;
|
||||
return *this;
|
||||
}
|
||||
any &operator=(any &&) noexcept = default;
|
||||
|
||||
template <
|
||||
typename T, typename D = typename std::decay<T>::type,
|
||||
typename std::enable_if<!std::is_same<D, any>::value, int>::type = 0>
|
||||
any(T &&v) : storage_(new detail::any_value<D>(std::forward<T>(v))) {}
|
||||
|
||||
template <
|
||||
typename T, typename D = typename std::decay<T>::type,
|
||||
typename std::enable_if<!std::is_same<D, any>::value, int>::type = 0>
|
||||
any &operator=(T &&v) {
|
||||
storage_.reset(new detail::any_value<D>(std::forward<T>(v)));
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool has_value() const noexcept { return storage_ != nullptr; }
|
||||
void reset() noexcept { storage_.reset(); }
|
||||
|
||||
template <typename T> friend T *any_cast(any *a) noexcept;
|
||||
template <typename T> friend const T *any_cast(const any *a) noexcept;
|
||||
};
|
||||
|
||||
template <typename T> T *any_cast(any *a) noexcept {
|
||||
if (!a || !a->storage_) { return nullptr; }
|
||||
if (a->storage_->type_id() != detail::any_typeid<T>()) { return nullptr; }
|
||||
return &static_cast<detail::any_value<T> *>(a->storage_.get())->value;
|
||||
}
|
||||
|
||||
template <typename T> const T *any_cast(const any *a) noexcept {
|
||||
if (!a || !a->storage_) { return nullptr; }
|
||||
if (a->storage_->type_id() != detail::any_typeid<T>()) { return nullptr; }
|
||||
return &static_cast<const detail::any_value<T> *>(a->storage_.get())->value;
|
||||
}
|
||||
|
||||
template <typename T> T any_cast(const any &a) {
|
||||
using U =
|
||||
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
|
||||
const U *p = any_cast<U>(&a);
|
||||
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
|
||||
if (!p) { throw bad_any_cast{}; }
|
||||
#else
|
||||
if (!p) { std::abort(); }
|
||||
#endif
|
||||
return static_cast<T>(*p);
|
||||
}
|
||||
|
||||
template <typename T> T any_cast(any &a) {
|
||||
using U =
|
||||
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
|
||||
U *p = any_cast<U>(&a);
|
||||
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
|
||||
if (!p) { throw bad_any_cast{}; }
|
||||
#else
|
||||
if (!p) { std::abort(); }
|
||||
#endif
|
||||
return static_cast<T>(*p);
|
||||
}
|
||||
|
||||
template <typename T> T any_cast(any &&a) {
|
||||
using U =
|
||||
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
|
||||
U *p = any_cast<U>(&a);
|
||||
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
|
||||
if (!p) { throw bad_any_cast{}; }
|
||||
#else
|
||||
if (!p) { std::abort(); }
|
||||
#endif
|
||||
return static_cast<T>(std::move(*p));
|
||||
}
|
||||
|
||||
#endif // __cplusplus >= 201703L
|
||||
|
||||
struct Response;
|
||||
using ResponseHandler = std::function<bool(const Response &response)>;
|
||||
|
||||
|
|
@ -805,6 +1010,60 @@ struct FormDataProvider {
|
|||
};
|
||||
using FormDataProviderItems = std::vector<FormDataProvider>;
|
||||
|
||||
inline FormDataProvider
|
||||
make_file_provider(const std::string &name, const std::string &filepath,
|
||||
const std::string &filename = std::string(),
|
||||
const std::string &content_type = std::string()) {
|
||||
FormDataProvider fdp;
|
||||
fdp.name = name;
|
||||
fdp.filename = filename.empty() ? filepath : filename;
|
||||
fdp.content_type = content_type;
|
||||
fdp.provider = [filepath](size_t offset, DataSink &sink) -> bool {
|
||||
std::ifstream f(filepath, std::ios::binary);
|
||||
if (!f) { return false; }
|
||||
if (offset > 0) {
|
||||
f.seekg(static_cast<std::streamoff>(offset));
|
||||
if (!f.good()) {
|
||||
sink.done();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
char buf[8192];
|
||||
f.read(buf, sizeof(buf));
|
||||
auto n = static_cast<size_t>(f.gcount());
|
||||
if (n > 0) { return sink.write(buf, n); }
|
||||
sink.done(); // EOF
|
||||
return true;
|
||||
};
|
||||
return fdp;
|
||||
}
|
||||
|
||||
inline std::pair<size_t, ContentProvider>
|
||||
make_file_body(const std::string &filepath) {
|
||||
std::ifstream f(filepath, std::ios::binary | std::ios::ate);
|
||||
if (!f) { return {0, ContentProvider{}}; }
|
||||
auto size = static_cast<size_t>(f.tellg());
|
||||
|
||||
ContentProvider provider = [filepath](size_t offset, size_t length,
|
||||
DataSink &sink) -> bool {
|
||||
std::ifstream f(filepath, std::ios::binary);
|
||||
if (!f) { return false; }
|
||||
f.seekg(static_cast<std::streamoff>(offset));
|
||||
if (!f.good()) { return false; }
|
||||
char buf[8192];
|
||||
while (length > 0) {
|
||||
auto to_read = (std::min)(sizeof(buf), length);
|
||||
f.read(buf, static_cast<std::streamsize>(to_read));
|
||||
auto n = static_cast<size_t>(f.gcount());
|
||||
if (n == 0) { break; }
|
||||
if (!sink.write(buf, n)) { return false; }
|
||||
length -= n;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
return {size, std::move(provider)};
|
||||
}
|
||||
|
||||
using ContentReceiverWithProgress = std::function<bool(
|
||||
const char *data, size_t data_length, size_t offset, size_t total_length)>;
|
||||
|
||||
|
|
@ -1010,6 +1269,10 @@ struct Response {
|
|||
std::string body;
|
||||
std::string location; // Redirect location
|
||||
|
||||
// User-defined context — set by pre-routing/pre-request handlers and read
|
||||
// by route handlers to pass arbitrary data (e.g. decoded auth tokens).
|
||||
std::map<std::string, any> user_data;
|
||||
|
||||
bool has_header(const std::string &key) const;
|
||||
std::string get_header_value(const std::string &key, const char *def = "",
|
||||
size_t id = 0) const;
|
||||
|
|
@ -1115,6 +1378,7 @@ public:
|
|||
virtual bool is_readable() const = 0;
|
||||
virtual bool wait_readable() const = 0;
|
||||
virtual bool wait_writable() const = 0;
|
||||
virtual bool is_peer_alive() const { return wait_writable(); }
|
||||
|
||||
virtual ssize_t read(char *ptr, size_t size) = 0;
|
||||
virtual ssize_t write(const char *ptr, size_t size) = 0;
|
||||
|
|
@ -1124,6 +1388,11 @@ public:
|
|||
|
||||
virtual time_t duration() const = 0;
|
||||
|
||||
virtual void set_read_timeout(time_t sec, time_t usec = 0) {
|
||||
(void)sec;
|
||||
(void)usec;
|
||||
}
|
||||
|
||||
ssize_t write(const char *ptr);
|
||||
ssize_t write(const std::string &s);
|
||||
|
||||
|
|
@ -1146,7 +1415,7 @@ public:
|
|||
|
||||
class ThreadPool final : public TaskQueue {
|
||||
public:
|
||||
explicit ThreadPool(size_t n, size_t mqr = 0);
|
||||
explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0);
|
||||
ThreadPool(const ThreadPool &) = delete;
|
||||
~ThreadPool() override = default;
|
||||
|
||||
|
|
@ -1154,20 +1423,22 @@ public:
|
|||
void shutdown() override;
|
||||
|
||||
private:
|
||||
struct worker {
|
||||
explicit worker(ThreadPool &pool);
|
||||
void worker(bool is_dynamic);
|
||||
void move_to_finished(std::thread::id id);
|
||||
void cleanup_finished_threads();
|
||||
|
||||
void operator()();
|
||||
|
||||
ThreadPool &pool_;
|
||||
};
|
||||
friend struct worker;
|
||||
|
||||
std::vector<std::thread> threads_;
|
||||
std::list<std::function<void()>> jobs_;
|
||||
size_t base_thread_count_;
|
||||
size_t max_thread_count_;
|
||||
size_t max_queued_requests_;
|
||||
size_t idle_thread_count_;
|
||||
|
||||
bool shutdown_;
|
||||
size_t max_queued_requests_ = 0;
|
||||
|
||||
std::list<std::function<void()>> jobs_;
|
||||
std::vector<std::thread> threads_; // base threads
|
||||
std::list<std::thread> dynamic_threads_; // dynamic threads
|
||||
std::vector<std::thread>
|
||||
finished_threads_; // exited dynamic threads awaiting join
|
||||
|
||||
std::condition_variable cond_;
|
||||
std::mutex mutex_;
|
||||
|
|
@ -1294,6 +1565,11 @@ public:
|
|||
using Expect100ContinueHandler =
|
||||
std::function<int(const Request &, Response &)>;
|
||||
|
||||
using WebSocketHandler =
|
||||
std::function<void(const Request &, ws::WebSocket &)>;
|
||||
using SubProtocolSelector =
|
||||
std::function<std::string(const std::vector<std::string> &protocols)>;
|
||||
|
||||
Server();
|
||||
|
||||
virtual ~Server();
|
||||
|
|
@ -1311,6 +1587,10 @@ public:
|
|||
Server &Delete(const std::string &pattern, HandlerWithContentReader handler);
|
||||
Server &Options(const std::string &pattern, Handler handler);
|
||||
|
||||
Server &WebSocket(const std::string &pattern, WebSocketHandler handler);
|
||||
Server &WebSocket(const std::string &pattern, WebSocketHandler handler,
|
||||
SubProtocolSelector sub_protocol_selector);
|
||||
|
||||
bool set_base_dir(const std::string &dir,
|
||||
const std::string &mount_point = std::string());
|
||||
bool set_mount_point(const std::string &mount_point, const std::string &dir,
|
||||
|
|
@ -1386,7 +1666,8 @@ protected:
|
|||
int remote_port, const std::string &local_addr,
|
||||
int local_port, bool close_connection,
|
||||
bool &connection_closed,
|
||||
const std::function<void(Request &)> &setup_request);
|
||||
const std::function<void(Request &)> &setup_request,
|
||||
bool *websocket_upgraded = nullptr);
|
||||
|
||||
std::atomic<socket_t> svr_sock_{INVALID_SOCKET};
|
||||
|
||||
|
|
@ -1488,6 +1769,14 @@ private:
|
|||
HandlersForContentReader delete_handlers_for_content_reader_;
|
||||
Handlers options_handlers_;
|
||||
|
||||
struct WebSocketHandlerEntry {
|
||||
std::unique_ptr<detail::MatcherBase> matcher;
|
||||
WebSocketHandler handler;
|
||||
SubProtocolSelector sub_protocol_selector;
|
||||
};
|
||||
using WebSocketHandlers = std::vector<WebSocketHandlerEntry>;
|
||||
WebSocketHandlers websocket_handlers_;
|
||||
|
||||
HandlerWithResponse error_handler_;
|
||||
ExceptionHandler exception_handler_;
|
||||
HandlerWithResponse pre_routing_handler_;
|
||||
|
|
@ -2970,6 +3259,36 @@ struct MbedTlsContext {
|
|||
} // namespace tls
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT
|
||||
namespace tls {
|
||||
namespace impl {
|
||||
|
||||
// wolfSSL context wrapper (holds WOLFSSL_CTX and related state).
|
||||
// This struct is accessible via tls::impl for use in SSL context
|
||||
// setup callbacks (cast ctx_t to tls::impl::WolfSSLContext*).
|
||||
struct WolfSSLContext {
|
||||
WOLFSSL_CTX *ctx = nullptr;
|
||||
bool is_server = false;
|
||||
bool verify_client = false;
|
||||
bool has_verify_callback = false;
|
||||
std::string ca_pem_data_; // accumulated PEM for get_ca_names/get_ca_certs
|
||||
|
||||
WolfSSLContext();
|
||||
~WolfSSLContext();
|
||||
|
||||
WolfSSLContext(const WolfSSLContext &) = delete;
|
||||
WolfSSLContext &operator=(const WolfSSLContext &) = delete;
|
||||
};
|
||||
|
||||
// CA store for wolfSSL: holds raw PEM bytes to allow reloading into any ctx
|
||||
struct WolfSSLCAStore {
|
||||
std::string pem_data;
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
} // namespace tls
|
||||
#endif
|
||||
|
||||
#endif // CPPHTTPLIB_SSL_ENABLED
|
||||
|
||||
namespace stream {
|
||||
|
|
@ -3335,6 +3654,143 @@ private:
|
|||
|
||||
} // namespace sse
|
||||
|
||||
namespace ws {
|
||||
|
||||
enum class Opcode : uint8_t {
|
||||
Continuation = 0x0,
|
||||
Text = 0x1,
|
||||
Binary = 0x2,
|
||||
Close = 0x8,
|
||||
Ping = 0x9,
|
||||
Pong = 0xA,
|
||||
};
|
||||
|
||||
enum class CloseStatus : uint16_t {
|
||||
Normal = 1000,
|
||||
GoingAway = 1001,
|
||||
ProtocolError = 1002,
|
||||
UnsupportedData = 1003,
|
||||
NoStatus = 1005,
|
||||
Abnormal = 1006,
|
||||
InvalidPayload = 1007,
|
||||
PolicyViolation = 1008,
|
||||
MessageTooBig = 1009,
|
||||
MandatoryExtension = 1010,
|
||||
InternalError = 1011,
|
||||
};
|
||||
|
||||
enum ReadResult : int { Fail = 0, Text = 1, Binary = 2 };
|
||||
|
||||
class WebSocket {
|
||||
public:
|
||||
WebSocket(const WebSocket &) = delete;
|
||||
WebSocket &operator=(const WebSocket &) = delete;
|
||||
~WebSocket();
|
||||
|
||||
ReadResult read(std::string &msg);
|
||||
bool send(const std::string &data);
|
||||
bool send(const char *data, size_t len);
|
||||
void close(CloseStatus status = CloseStatus::Normal,
|
||||
const std::string &reason = "");
|
||||
const Request &request() const;
|
||||
bool is_open() const;
|
||||
|
||||
private:
|
||||
friend class httplib::Server;
|
||||
friend class WebSocketClient;
|
||||
|
||||
WebSocket(Stream &strm, const Request &req, bool is_server)
|
||||
: strm_(strm), req_(req), is_server_(is_server) {
|
||||
start_heartbeat();
|
||||
}
|
||||
|
||||
WebSocket(std::unique_ptr<Stream> &&owned_strm, const Request &req,
|
||||
bool is_server)
|
||||
: strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req),
|
||||
is_server_(is_server) {
|
||||
start_heartbeat();
|
||||
}
|
||||
|
||||
void start_heartbeat();
|
||||
bool send_frame(Opcode op, const char *data, size_t len, bool fin = true);
|
||||
|
||||
Stream &strm_;
|
||||
std::unique_ptr<Stream> owned_strm_;
|
||||
Request req_;
|
||||
bool is_server_;
|
||||
std::atomic<bool> closed_{false};
|
||||
std::mutex write_mutex_;
|
||||
std::thread ping_thread_;
|
||||
std::mutex ping_mutex_;
|
||||
std::condition_variable ping_cv_;
|
||||
};
|
||||
|
||||
class WebSocketClient {
|
||||
public:
|
||||
explicit WebSocketClient(const std::string &scheme_host_port_path,
|
||||
const Headers &headers = {});
|
||||
|
||||
~WebSocketClient();
|
||||
WebSocketClient(const WebSocketClient &) = delete;
|
||||
WebSocketClient &operator=(const WebSocketClient &) = delete;
|
||||
|
||||
bool is_valid() const;
|
||||
|
||||
bool connect();
|
||||
ReadResult read(std::string &msg);
|
||||
bool send(const std::string &data);
|
||||
bool send(const char *data, size_t len);
|
||||
void close(CloseStatus status = CloseStatus::Normal,
|
||||
const std::string &reason = "");
|
||||
bool is_open() const;
|
||||
const std::string &subprotocol() const;
|
||||
void set_read_timeout(time_t sec, time_t usec = 0);
|
||||
void set_write_timeout(time_t sec, time_t usec = 0);
|
||||
|
||||
#ifdef CPPHTTPLIB_SSL_ENABLED
|
||||
void set_ca_cert_path(const std::string &path);
|
||||
void set_ca_cert_store(tls::ca_store_t store);
|
||||
void enable_server_certificate_verification(bool enabled);
|
||||
#endif
|
||||
|
||||
private:
|
||||
void shutdown_and_close();
|
||||
bool create_stream(std::unique_ptr<Stream> &strm);
|
||||
|
||||
std::string host_;
|
||||
int port_;
|
||||
std::string path_;
|
||||
Headers headers_;
|
||||
std::string subprotocol_;
|
||||
bool is_valid_ = false;
|
||||
socket_t sock_ = INVALID_SOCKET;
|
||||
std::unique_ptr<WebSocket> ws_;
|
||||
time_t read_timeout_sec_ = CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND;
|
||||
time_t read_timeout_usec_ = 0;
|
||||
time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND;
|
||||
time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
|
||||
|
||||
#ifdef CPPHTTPLIB_SSL_ENABLED
|
||||
bool is_ssl_ = false;
|
||||
tls::ctx_t tls_ctx_ = nullptr;
|
||||
tls::session_t tls_session_ = nullptr;
|
||||
std::string ca_cert_file_path_;
|
||||
tls::ca_store_t ca_cert_store_ = nullptr;
|
||||
bool server_certificate_verification_ = true;
|
||||
#endif
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
bool is_valid_utf8(const std::string &s);
|
||||
|
||||
bool read_websocket_frame(Stream &strm, Opcode &opcode, std::string &payload,
|
||||
bool &fin, bool expect_masked, size_t max_len);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
} // namespace ws
|
||||
|
||||
|
||||
} // namespace httplib
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue