Compare commits
19 Commits
75225a93d6
...
6331fddba5
| Author | SHA1 | Date |
|---|---|---|
|
|
6331fddba5 | |
|
|
6de97b9d3e | |
|
|
5a0ed5150a | |
|
|
8710e5f9b9 | |
|
|
1d6d4cf7a5 | |
|
|
744c0c7310 | |
|
|
0356e33aaf | |
|
|
6422036fcb | |
|
|
296bc0538b | |
|
|
6b949d1078 | |
|
|
84f82e846c | |
|
|
e1cb817483 | |
|
|
88d5f8ffc3 | |
|
|
d43375ff7f | |
|
|
2b86e5cae6 | |
|
|
88458164c7 | |
|
|
4951250235 | |
|
|
82764c341a | |
|
|
cfd5d05816 |
|
|
@ -150,16 +150,15 @@ jobs:
|
|||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
run: |
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_VERSION="v20260317.182325"
|
||||
DAWN_OWNER="google"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
curl -L -o artifact.zip \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-macos-latest-Release"
|
||||
echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
|
||||
curl -L -o artifact.tar.gz \
|
||||
"https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
|
||||
mkdir dawn
|
||||
unzip artifact.zip
|
||||
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
|
||||
tar -xvf artifact.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
|
|
@ -384,16 +383,15 @@ jobs:
|
|||
id: dawn-depends
|
||||
run: |
|
||||
sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
|
||||
DAWN_VERSION="v2.0.0"
|
||||
DAWN_OWNER="reeselevine"
|
||||
DAWN_VERSION="v20260317.182325"
|
||||
DAWN_OWNER="google"
|
||||
DAWN_REPO="dawn"
|
||||
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release"
|
||||
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
curl -L -o artifact.zip \
|
||||
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
|
||||
DAWN_ASSET_NAME="Dawn-18eb229ef5f707c1464cc581252e7603c73a3ef0-ubuntu-latest-Release"
|
||||
echo "Fetching release asset from https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
|
||||
curl -L -o artifact.tar.gz \
|
||||
"https://github.com/google/dawn/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.tar.gz"
|
||||
mkdir dawn
|
||||
unzip artifact.zip
|
||||
tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
|
||||
tar -xvf artifact.tar.gz -C dawn --strip-components=1
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
|
|
@ -427,7 +425,7 @@ jobs:
|
|||
|
||||
- name: Fetch emdawnwebgpu
|
||||
run: |
|
||||
DAWN_TAG="v20251027.212519"
|
||||
DAWN_TAG="v20260317.182325"
|
||||
EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip"
|
||||
echo "Downloading ${EMDAWN_PKG}"
|
||||
curl -L -o emdawn.zip \
|
||||
|
|
|
|||
30
ci/run.sh
30
ci/run.sh
|
|
@ -151,35 +151,7 @@ fi
|
|||
|
||||
if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
|
||||
echo ">>===== Enabling KleidiAI support"
|
||||
|
||||
CANDIDATES=(
|
||||
"armv9-a+dotprod+i8mm+sve2"
|
||||
"armv9-a+dotprod+i8mm"
|
||||
"armv8.6-a+dotprod+i8mm"
|
||||
"armv8.2-a+dotprod"
|
||||
)
|
||||
CPU=""
|
||||
|
||||
for cpu in "${CANDIDATES[@]}"; do
|
||||
if echo 'int main(){}' | ${CXX:-c++} -march="$cpu" -x c++ - -c -o /dev/null >/dev/null 2>&1; then
|
||||
CPU="$cpu"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$CPU" ]; then
|
||||
echo "ERROR: None of the required ARM baselines (armv9/armv8.6/armv8.2 + dotprod) are supported by this compiler."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ">>===== Using ARM baseline: ${CPU}"
|
||||
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CPU_KLEIDIAI=ON \
|
||||
-DGGML_CPU_AARCH64=ON \
|
||||
-DGGML_CPU_ARM_ARCH=${CPU} \
|
||||
-DBUILD_SHARED_LIBS=OFF"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } -DGGML_CPU_KLEIDIAI=ON"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_BLAS} ]; then
|
||||
|
|
|
|||
100
common/chat.cpp
100
common/chat.cpp
|
|
@ -1274,11 +1274,12 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
return data;
|
||||
}
|
||||
|
||||
// LFM2 format:
|
||||
// - Reasoning: <think>{reasoning}</think> (optional, only if enable_thinking is true)
|
||||
// - Content: text after reasoning (optional)
|
||||
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
|
||||
// Tool calls can appear multiple times (parallel tool calls)
|
||||
// LFM2 format: uses <|tool_list_start|>[...]<|tool_list_end|> in system prompt
|
||||
// and <|tool_call_start|>[name(arg="val")]<|tool_call_end|> for tool calls.
|
||||
// - Reasoning: <think>{reasoning}</think> (optional)
|
||||
// - Content: text before a tool call (optional)
|
||||
// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")]
|
||||
// Tool calls can appear multiple times (parallel tool calls supported)
|
||||
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
|
@ -1319,9 +1320,9 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
p.trigger_rule("tool-call", p.literal(TOOL_CALL_START) +
|
||||
p.trigger_rule("tool-call",
|
||||
p.literal(TOOL_CALL_START) +
|
||||
p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls) +
|
||||
p.literal(TOOL_CALL_END)
|
||||
)
|
||||
|
|
@ -1349,6 +1350,80 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, TOOL_CALL_START }
|
||||
};
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
// LFM2.5 format: uses plain "List of tools: [...]" in system prompt, no wrapper tokens.
|
||||
// Tool calls are bare [name(arg="val")], though model may optionally emit <|tool_call_start|>.
|
||||
// - Reasoning: <think>{reasoning}</think> (optional)
|
||||
// - Content: text before a tool call (optional)
|
||||
// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")]
|
||||
// Tool calls can appear multiple times (parallel tool calls supported)
|
||||
static common_chat_params common_chat_params_init_lfm2_5(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
"<|tool_call_start|>",
|
||||
"<|tool_call_end|>",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
|
||||
const std::string THINK_START = "<think>";
|
||||
const std::string THINK_END = "</think>";
|
||||
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
|
||||
auto end = p.end();
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning && inputs.enable_thinking) {
|
||||
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
p.trigger_rule("tool-call",
|
||||
p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls)
|
||||
)
|
||||
);
|
||||
|
||||
auto content = p.content(p.until_one_of({"<|tool_call_start|>", "["}));
|
||||
auto maybe_start = p.optional(p.literal("<|tool_call_start|>"));
|
||||
return generation_prompt + reasoning + content + maybe_start + tool_calls + end;
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const std::string name = tool.at("function").at("name");
|
||||
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[" + name + "(" });
|
||||
});
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
|
@ -1530,14 +1605,21 @@ static std::optional<common_chat_params> try_specialized_template(
|
|||
return common_chat_params_init_kimi_k2(tmpl, params);
|
||||
}
|
||||
|
||||
// LFM2 - uses <|tool_list_start|>/<|tool_list_end|> markers and <|tool_call_start|>[name(args)]<|tool_call_end|> format
|
||||
// Detection: template has "<|tool_list_start|>" and "<|tool_list_end|>" markers
|
||||
// LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list
|
||||
// and <|tool_call_start|>[...]<|tool_call_end|> around each tool call
|
||||
if (src.find("<|tool_list_start|>") != std::string::npos &&
|
||||
src.find("<|tool_list_end|>") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: LFM2\n");
|
||||
return common_chat_params_init_lfm2(tmpl, params);
|
||||
}
|
||||
|
||||
// LFM2.5 format detection: template uses plain "List of tools: [...]" with no special tokens
|
||||
if (src.find("List of tools: [") != std::string::npos &&
|
||||
src.find("<|tool_list_start|>") == std::string::npos) {
|
||||
LOG_DBG("Using specialized template: LFM2.5\n");
|
||||
return common_chat_params_init_lfm2_5(tmpl, params);
|
||||
}
|
||||
|
||||
// GigaChatV3 format detection
|
||||
if (src.find("<|role_sep|>") != std::string::npos &&
|
||||
src.find("<|message_sep|>") != std::string::npos &&
|
||||
|
|
|
|||
|
|
@ -728,7 +728,7 @@ To read documentation for how to build on Android, [click here](./android.md)
|
|||
|
||||
## WebGPU [In Progress]
|
||||
|
||||
The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The current implementation is up-to-date with Dawn commit `bed1a61`.
|
||||
The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The current implementation is up-to-date with Dawn commit `18eb229`.
|
||||
|
||||
In the llama.cpp directory, build with CMake:
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
|||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 9)
|
||||
set(GGML_VERSION_PATCH 10)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
|
@ -166,15 +166,16 @@ if (NOT MSVC)
|
|||
option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
|
||||
option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
|
||||
endif()
|
||||
option(GGML_LASX "ggml: enable lasx" ON)
|
||||
option(GGML_LSX "ggml: enable lsx" ON)
|
||||
option(GGML_RVV "ggml: enable rvv" ON)
|
||||
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
|
||||
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
|
||||
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
|
||||
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON)
|
||||
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
|
||||
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
|
||||
option(GGML_LASX "ggml: enable lasx" ON)
|
||||
option(GGML_LSX "ggml: enable lsx" ON)
|
||||
option(GGML_RVV "ggml: enable rvv" ON)
|
||||
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
|
||||
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
|
||||
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
|
||||
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON)
|
||||
option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF)
|
||||
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
|
||||
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
|
||||
|
||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
|
|
|
|||
|
|
@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
const int64_t n_heads = node->src[1]->ne[1];
|
||||
n_tasks = MIN(n_threads, n_heads);
|
||||
} break;
|
||||
case GGML_OP_WIN_PART:
|
||||
case GGML_OP_WIN_UNPART:
|
||||
|
|
|
|||
|
|
@ -180,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
|
|||
}
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_zvfh)
|
||||
template <>
|
||||
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
|
||||
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
||||
}
|
||||
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
|
||||
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
||||
}
|
||||
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
|
||||
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
||||
}
|
||||
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
|
||||
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
||||
}
|
||||
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
|
||||
return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
||||
}
|
||||
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
|
||||
template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
|
||||
return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
||||
}
|
||||
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
|
||||
template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
|
||||
return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
||||
}
|
||||
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
|
||||
template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
|
||||
return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_zvfh)
|
||||
template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
|
||||
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
||||
}
|
||||
template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
|
||||
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
||||
}
|
||||
template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
|
||||
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
||||
}
|
||||
template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
|
||||
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_zvfbfwma)
|
||||
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
|
||||
template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
|
||||
return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
|
||||
}
|
||||
inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
|
||||
template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
|
||||
return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
|
||||
}
|
||||
inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
|
||||
template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
|
||||
return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
|
||||
}
|
||||
template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) {
|
||||
return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
@ -272,7 +277,7 @@ inline float hsum(__m512 x) {
|
|||
}
|
||||
#endif // __AVX512F__
|
||||
|
||||
#if defined(__riscv_zvfh)
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
inline float hsum(vfloat32m1_t x) {
|
||||
return __riscv_vfmv_f_s_f32m1_f32(
|
||||
__riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
|
||||
|
|
@ -379,19 +384,7 @@ template <> inline __m256bh load(const float *p) {
|
|||
}
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_zvfh)
|
||||
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
|
||||
}
|
||||
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
|
||||
}
|
||||
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
|
||||
}
|
||||
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
|
||||
}
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
template <> inline vfloat32m1_t load(const float *p) {
|
||||
return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
|
||||
}
|
||||
|
|
@ -406,6 +399,21 @@ template <> inline vfloat32m8_t load(const float *p) {
|
|||
}
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_zvfh)
|
||||
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
|
||||
}
|
||||
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
|
||||
}
|
||||
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
|
||||
}
|
||||
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_zvfbfwma)
|
||||
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
|
||||
return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
|
||||
|
|
@ -416,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
|
|||
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
|
||||
return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
|
||||
}
|
||||
template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) {
|
||||
return __riscv_vle16_v_bf16m4(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m4());
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_zvfh)
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
template <typename T> T set_zero();
|
||||
|
||||
template <> inline vfloat16mf2_t set_zero() {
|
||||
return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
|
||||
}
|
||||
template <> inline vfloat16m1_t set_zero() {
|
||||
return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
|
||||
}
|
||||
template <> inline vfloat16m2_t set_zero() {
|
||||
return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
|
||||
}
|
||||
template <> inline vfloat16m4_t set_zero() {
|
||||
return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
|
||||
}
|
||||
template <> inline vfloat32m1_t set_zero() {
|
||||
return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
|
||||
}
|
||||
|
|
@ -449,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() {
|
|||
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
template <typename T> size_t vlmax() {
|
||||
if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
|
||||
if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
|
||||
#if defined (__riscv_zvfh)
|
||||
else if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
||||
#endif
|
||||
#if defined (__riscv_zvfbfwma)
|
||||
else if constexpr (std::is_same_v<T, vbfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
||||
else if constexpr (std::is_same_v<T, vbfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
||||
else if constexpr (std::is_same_v<T, vbfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
|
||||
else if constexpr (std::is_same_v<T, vbfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
|
@ -3740,7 +3747,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|||
params->ith, params->nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__riscv_zvfh)
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
#if LMUL == 1
|
||||
tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
|
||||
k, (const float *)A, lda,
|
||||
|
|
@ -3804,23 +3811,25 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|||
return true;
|
||||
}
|
||||
#elif defined(__riscv_zvfbfwma)
|
||||
#if LMUL == 1
|
||||
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
||||
k, (const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
#elif LMUL == 2
|
||||
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
||||
k, (const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
#else // LMUL = 4
|
||||
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
||||
k, (const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
#endif
|
||||
return tb.matmul(m, n);
|
||||
if (Btype == GGML_TYPE_BF16) {
|
||||
#if LMUL == 1
|
||||
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
||||
k, (const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
#elif LMUL == 2
|
||||
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
||||
k, (const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
#else // LMUL = 4
|
||||
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
|
||||
k, (const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
#endif
|
||||
return tb.matmul(m, n);
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
|
|||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
const int h_start = (HEADS * (ith )) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * k = (float *) dst->src[0]->data;
|
||||
float * v = (float *) dst->src[1]->data;
|
||||
|
|
@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32(
|
|||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
const int h_start = (HEADS * (ith )) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * k = (float *) dst->src[0]->data;
|
||||
float * v = (float *) dst->src[1]->data;
|
||||
|
|
@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
|
|||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
const int h_start = (HEADS * (ith )) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * r = (float *) dst->src[0]->data;
|
||||
float * w = (float *) dst->src[1]->data;
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
|||
const int ggml_f16_epr = sve_register_length / 16; // running when 16
|
||||
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
|
||||
|
||||
const int np = (n & ~(ggml_f16_step - 1));
|
||||
int np = (n & ~(ggml_f16_step - 1));
|
||||
|
||||
svfloat16_t sum_00 = svdup_n_f16(0.0f);
|
||||
svfloat16_t sum_01 = svdup_n_f16(0.0f);
|
||||
|
|
@ -224,71 +224,75 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
|||
}
|
||||
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
|
||||
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
|
||||
np = n;
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
#if defined(__riscv_zvfh)
|
||||
size_t vl = __riscv_vsetvlmax_e32m4();
|
||||
|
||||
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
||||
size_t vl = __riscv_vsetvlmax_e32m4();
|
||||
// initialize accumulators to all zeroes
|
||||
vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||
vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||
vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||
vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||
|
||||
// initialize accumulators to all zeroes
|
||||
vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||
vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||
vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||
vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
|
||||
// calculate step size
|
||||
const size_t epr = __riscv_vsetvlmax_e16m2();
|
||||
const size_t step = epr * 2;
|
||||
int np = (n & ~(step - 1));
|
||||
|
||||
// calculate step size
|
||||
const size_t epr = __riscv_vsetvlmax_e16m2();
|
||||
const size_t step = epr * 2;
|
||||
const int np = (n & ~(step - 1));
|
||||
// unroll by 2 along the row dimension
|
||||
for (int i = 0; i < np; i += step) {
|
||||
vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
|
||||
vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
|
||||
vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
|
||||
vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
|
||||
vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
|
||||
|
||||
// unroll by 2 along the row dimension
|
||||
for (int i = 0; i < np; i += step) {
|
||||
vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
|
||||
vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
|
||||
vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
|
||||
vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
|
||||
vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
|
||||
vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
|
||||
vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
|
||||
vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
|
||||
vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
|
||||
vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
|
||||
}
|
||||
|
||||
vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
|
||||
vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
|
||||
vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
|
||||
vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
|
||||
vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
|
||||
}
|
||||
vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
|
||||
vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
|
||||
|
||||
vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
|
||||
vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
|
||||
// leftovers
|
||||
for (int i = np; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m2(n - i);
|
||||
vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
|
||||
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
|
||||
vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m2(n - i);
|
||||
vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
|
||||
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
|
||||
vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
|
||||
vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
|
||||
vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
|
||||
}
|
||||
|
||||
vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
|
||||
vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
|
||||
}
|
||||
|
||||
// reduce
|
||||
vl = __riscv_vsetvlmax_e32m2();
|
||||
vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
|
||||
__riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
|
||||
vl = __riscv_vsetvlmax_e32m1();
|
||||
vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
|
||||
__riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
|
||||
vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
|
||||
acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||
|
||||
vl = __riscv_vsetvlmax_e32m2();
|
||||
vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
|
||||
__riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
|
||||
vl = __riscv_vsetvlmax_e32m1();
|
||||
vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
|
||||
__riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
|
||||
vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
|
||||
acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||
sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
|
||||
sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
|
||||
// reduce
|
||||
vl = __riscv_vsetvlmax_e32m2();
|
||||
vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
|
||||
__riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
|
||||
vl = __riscv_vsetvlmax_e32m1();
|
||||
vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
|
||||
__riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
|
||||
vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
|
||||
acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||
|
||||
vl = __riscv_vsetvlmax_e32m2();
|
||||
vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
|
||||
__riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
|
||||
vl = __riscv_vsetvlmax_e32m1();
|
||||
vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
|
||||
__riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
|
||||
vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
|
||||
acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||
sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
|
||||
sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
|
||||
np = n;
|
||||
#else
|
||||
const int np = 0;
|
||||
#endif
|
||||
#else
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
|
|
@ -313,21 +317,17 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
|||
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
|
||||
GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
for (int i = 0; i < n; ++i) {
|
||||
// scalar path
|
||||
const int np = 0;
|
||||
#endif
|
||||
// scalar and leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
|
||||
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
|
||||
s[i] = (float)sumf[i];
|
||||
|
|
@ -532,40 +532,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
|
|||
svst1_f16(pg, (__fp16 *)(y + np2), hy);
|
||||
}
|
||||
np = n;
|
||||
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
|
||||
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
|
||||
const _Float16 scale = *(const _Float16*)(&s);
|
||||
#elif defined(__riscv_v_intrinsic) // implies __riscv_v_intrinsic
|
||||
#if defined (__riscv_zvfh)
|
||||
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
|
||||
const _Float16 scale = *(const _Float16*)(&s);
|
||||
|
||||
// calculate step size
|
||||
const int epr = __riscv_vsetvlmax_e16m4();
|
||||
const int step = epr * 2;
|
||||
int np = (n & ~(step - 1));
|
||||
// calculate step size
|
||||
const int epr = __riscv_vsetvlmax_e16m4();
|
||||
const int step = epr * 2;
|
||||
int np = (n & ~(step - 1));
|
||||
|
||||
// unroll by 2
|
||||
for (int i = 0; i < np; i += step) {
|
||||
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
|
||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
|
||||
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
|
||||
__asm__ __volatile__ ("" ::: "memory");
|
||||
// unroll by 2
|
||||
for (int i = 0; i < np; i += step) {
|
||||
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
|
||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
|
||||
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
|
||||
__asm__ __volatile__ ("" ::: "memory");
|
||||
|
||||
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
|
||||
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
|
||||
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
|
||||
__asm__ __volatile__ ("" ::: "memory");
|
||||
}
|
||||
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
|
||||
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
|
||||
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
|
||||
__asm__ __volatile__ ("" ::: "memory");
|
||||
}
|
||||
|
||||
// leftovers
|
||||
int vl;
|
||||
for (int i = np; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m4(n - i);
|
||||
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
|
||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
|
||||
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
|
||||
}
|
||||
np = n;
|
||||
// leftovers
|
||||
int vl;
|
||||
for (int i = np; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m4(n - i);
|
||||
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
|
||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
|
||||
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
|
||||
}
|
||||
np = n;
|
||||
#else
|
||||
// fall to scalar path
|
||||
const int np = 0;
|
||||
#endif
|
||||
#elif defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
|
|
@ -584,10 +589,11 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
|
|||
}
|
||||
}
|
||||
#else
|
||||
// scalar path
|
||||
const int np = 0;
|
||||
#endif
|
||||
|
||||
// leftovers
|
||||
// scalar and leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
|
|
@ -785,7 +791,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
|
|||
const int ggml_f16_step = 2 * ggml_f16_epr;
|
||||
|
||||
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
|
||||
const int np = (n & ~(ggml_f16_step - 1));
|
||||
int np = (n & ~(ggml_f16_step - 1));
|
||||
svfloat16_t ay1, ay2;
|
||||
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
|
|
@ -805,36 +811,43 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
|
|||
svfloat16_t out = svmul_f16_m(pg, hy, vx);
|
||||
svst1_f16(pg, (__fp16 *)(y + np), out);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
|
||||
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
|
||||
const _Float16 scale = *(const _Float16*)(&s);
|
||||
np = n;
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
#if defined(__riscv_zvfh)
|
||||
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
|
||||
const _Float16 scale = *(const _Float16*)(&s);
|
||||
|
||||
// calculate step size
|
||||
const int epr = __riscv_vsetvlmax_e16m4();
|
||||
const int step = epr * 2;
|
||||
const int np = (n & ~(step - 1));
|
||||
// calculate step size
|
||||
const int epr = __riscv_vsetvlmax_e16m4();
|
||||
const int step = epr * 2;
|
||||
int np = (n & ~(step - 1));
|
||||
|
||||
// unroll by 2
|
||||
for (int i = 0; i < np; i += step) {
|
||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
|
||||
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
|
||||
__asm__ __volatile__ ("" ::: "memory");
|
||||
// unroll by 2
|
||||
for (int i = 0; i < np; i += step) {
|
||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
|
||||
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
|
||||
__asm__ __volatile__ ("" ::: "memory");
|
||||
|
||||
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
|
||||
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
|
||||
__asm__ __volatile__ ("" ::: "memory");
|
||||
}
|
||||
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
|
||||
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
|
||||
__asm__ __volatile__ ("" ::: "memory");
|
||||
}
|
||||
|
||||
// leftovers
|
||||
int vl;
|
||||
for (int i = np; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m4(n - i);
|
||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
|
||||
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
|
||||
}
|
||||
// leftovers
|
||||
int vl;
|
||||
for (int i = np; i < n; i += vl) {
|
||||
vl = __riscv_vsetvl_e16m4(n - i);
|
||||
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
|
||||
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
|
||||
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
|
||||
}
|
||||
np = n;
|
||||
#else
|
||||
// fall to scalar path
|
||||
const int np = 0;
|
||||
#endif
|
||||
#elif defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
|
|
@ -850,17 +863,14 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
|
|||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
#else
|
||||
// scalar path
|
||||
const int np = 0;
|
||||
#endif
|
||||
// scalar and leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
|
||||
|
|
|
|||
|
|
@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
|
|||
}
|
||||
|
||||
static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
|
||||
#ifdef FP8_AVAILABLE
|
||||
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
|
||||
#if defined(GGML_USE_HIP) && defined(CDNA3)
|
||||
// ROCm dose not support fp8 in software on devices with fp8 hardware,
|
||||
#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
|
||||
// ROCm does not support fp8 in software on devices with fp8 hardware,
|
||||
// but CDNA3 supports only e4m3_fnuz (no inf).
|
||||
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
|
||||
const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
|
||||
#else
|
||||
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
|
||||
#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP)
|
||||
return static_cast<float>(xf) / 2;
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP8_AVAILABLE
|
||||
#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
|
||||
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
|
||||
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
|
||||
return static_cast<float>(xf) / 2;
|
||||
#else
|
||||
if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f
|
||||
return 0.0f;
|
||||
}
|
||||
const int exp = (x >> 3) & 0xF;
|
||||
const int man = x & 0x7;
|
||||
float raw;
|
||||
if (exp == 0) {
|
||||
raw = ldexpf((float) man, -9);
|
||||
} else {
|
||||
raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
|
||||
}
|
||||
return static_cast<float>(raw / 2);
|
||||
#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
|
||||
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
|
||||
|
|
|
|||
|
|
@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
|
|
@ -80,6 +85,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
|
|
@ -89,6 +99,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|||
}
|
||||
|
||||
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
|
||||
|
|
@ -103,6 +118,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
|
|||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
|
||||
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
|
||||
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
|
||||
|
|
@ -1552,7 +1571,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
|
||||
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
|
@ -1815,6 +1834,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
|
|||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
|
||||
|
||||
// The number of viable configurations for Deepseek is very limited:
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
|
|
|
|||
|
|
@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
|
|||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 512: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
|
|
|
|||
|
|
@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
|||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
|
@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
|||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
|
@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
|||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
|
@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
|||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
|
|
@ -767,7 +785,7 @@ static __global__ void flash_attn_tile(
|
|||
#ifdef GGML_USE_WMMA_FATTN
|
||||
(ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
|
||||
#endif // GGML_USE_WMMA_FATTN
|
||||
(use_logit_softcap && !(DV == 128 || DV == 256))
|
||||
(use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512))
|
||||
) {
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
|
|
@ -1192,7 +1210,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
|||
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
if constexpr (DV == 512) {
|
||||
if constexpr (DKQ == 576) {
|
||||
if (use_gqa_opt && gqa_ratio % 16 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
|
|
@ -1203,7 +1221,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
|||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
if constexpr (DKQ <= 512) {
|
||||
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
|
|
@ -1214,13 +1232,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
|||
return;
|
||||
}
|
||||
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
|
||||
if constexpr (DV <= 256) {
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -1255,4 +1275,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
|
|||
extern DECL_FATTN_TILE_CASE(112, 112);
|
||||
extern DECL_FATTN_TILE_CASE(128, 128);
|
||||
extern DECL_FATTN_TILE_CASE(256, 256);
|
||||
extern DECL_FATTN_TILE_CASE(512, 512);
|
||||
extern DECL_FATTN_TILE_CASE(576, 512);
|
||||
|
|
|
|||
|
|
@ -135,6 +135,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|||
GGML_ASSERT(V->ne[0] == 256);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
|
||||
break;
|
||||
case 512:
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
|
||||
break;
|
||||
case 576: {
|
||||
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
|
|
@ -336,7 +340,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
case 512:
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
|
|
@ -424,7 +429,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
}
|
||||
|
||||
// Use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
|
||||
if (can_use_vector_kernel && Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
|
@ -457,7 +462,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
}
|
||||
|
||||
// Use MFMA flash attention for CDNA (MI100+):
|
||||
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
|
||||
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
|
||||
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
|
||||
// MMA vs tile crossover benchmarked on MI300X @ d32768:
|
||||
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)
|
||||
|
|
|
|||
|
|
@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_MXFP4:
|
||||
#ifdef FP8_AVAILABLE
|
||||
case GGML_TYPE_NVFP4:
|
||||
#endif // FP8_AVAILABLE
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
|
|||
case GGML_TYPE_MXFP4:
|
||||
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
|
||||
break;
|
||||
|
|
@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
|
|||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_MXFP4:
|
||||
case GGML_TYPE_NVFP4:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
|
|
@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
|
|||
}
|
||||
|
||||
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
|
|||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
case GGML_TYPE_MXFP4:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
case GGML_TYPE_NVFP4:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D4;
|
||||
case GGML_TYPE_Q2_K:
|
||||
return MMQ_Q8_1_DS_LAYOUT_D2S6;
|
||||
case GGML_TYPE_Q3_K:
|
||||
|
|
@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|||
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
|
||||
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
|
||||
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
|
||||
case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
|
||||
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
|
||||
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
|
||||
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
|
||||
|
|
@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
|
|||
}
|
||||
}
|
||||
|
||||
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
|
||||
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
|
||||
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
|
||||
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
|
||||
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
|
||||
|
||||
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
|
||||
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
|
||||
|
|
@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
|
|||
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
||||
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
|
||||
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
|
||||
static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
|
||||
|
||||
|
||||
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||
switch (type) {
|
||||
|
|
@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|||
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
||||
// tile sizes are the same for Q8_1 and FP4 for blackwell
|
||||
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
|
||||
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
||||
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
||||
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
||||
|
|
@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
// Loads one MMQ tile of NVFP4 data into the x tile buffer.
//
// For every tile row handled by this thread, one block_nvfp4 is read, its packed
// 4-bit codes are expanded through the kvalues_mxfp4 lookup table into 16 ints
// (4 per sub-block), and the per-sub-block UE4M3 scale is converted to fp32.
//
//   x       raw pointer to the quantized source blocks
//   x_tile  destination tile; quantized values first, followed by the float scales
//   kb0     index of the first block of the current K slice
//   i_max   last valid row index, only used for clamping when need_check is true
//   stride  distance between consecutive rows, in blocks
//
// The row pitch of the tile differs between the MMA layout and the dp4a layout,
// which is why both the base pointers and the per-row offsets are selected by
// the preprocessor.
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
                                                        int * __restrict__        x_tile,
                                                        const int                 kb0,
                                                        const int                 i_max,
                                                        const int                 stride) {
    constexpr int nwarps    = mmq_get_nwarps_device();
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();

    int * x_qs = (int *) x_tile;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
    constexpr int qs_row_pitch = MMQ_MMA_TILE_X_K_NVFP4;
    float *       x_df         = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
    constexpr tile_x_sizes txs          = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
    constexpr int          qs_row_pitch = 2 * MMQ_TILE_NE_K + 1;
    float *                x_df         = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

    // Each thread owns one whole block of a row; several rows fit into one warp.
    constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
    constexpr int rows_per_warp   = warp_size / threads_per_row;

    const int blk_in_row  = threadIdx.x % threads_per_row;
    const int row_in_warp = threadIdx.x / threads_per_row;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += rows_per_warp * nwarps) {
        int row = row0 + threadIdx.y * rows_per_warp + row_in_warp;

        if constexpr (need_check) {
            row = min(row, i_max);
        }

        const block_nvfp4 *           blk    = (const block_nvfp4 *) x + kb0 + row * stride + blk_in_row;
        const uint32_t * __restrict__ packed = reinterpret_cast<const uint32_t *>(blk->qs);

        // 16 ints of quants and 4 scales per block.
        int * dst_qs = x_qs + row * qs_row_pitch + 16 * blk_in_row;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
        float * dst_df = x_df + row * MMQ_MMA_TILE_X_K_NVFP4 + 4 * blk_in_row;
#else
        float * dst_df = x_df + row * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + row / (QK_NVFP4_SUB / QI_NVFP4) + 4 * blk_in_row;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)

#pragma unroll
        for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
            const int2 lo = get_int_from_table_16(packed[2 * sub + 0], kvalues_mxfp4);
            const int2 hi = get_int_from_table_16(packed[2 * sub + 1], kvalues_mxfp4);

            // Interleave the .x halves of both words before the .y halves.
            dst_qs[4 * sub + 0] = lo.x;
            dst_qs[4 * sub + 1] = hi.x;
            dst_qs[4 * sub + 2] = lo.y;
            dst_qs[4 * sub + 3] = hi.y;

            dst_df[sub] = ggml_cuda_ue4m3_to_fp32(blk->d[sub]);
        }
    }
}
|
||||
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||
|
|
@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
|
|||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
// Used for Q3_K, IQ2_S, and IQ2_XS
|
||||
// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
|
||||
template <int mmq_x, int mmq_y>
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
|
||||
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
|
||||
|
|
@ -3261,6 +3327,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
|
|||
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
|
||||
};
|
||||
|
||||
// MMQ traits for NVFP4: binds the NVFP4 tile loader to the vec_dot_q8_0_16_q8_1
// kernels for both the MMA and the dp4a code paths.
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
    static constexpr int              vdr          = VDR_NVFP4_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_nvfp4<mmq_y, need_check>;
    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
|
||||
|
||||
template <int mmq_x, int mmq_y, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
|
||||
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
||||
|
|
@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
|||
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
|
||||
|
|
|
|||
|
|
@ -235,30 +235,33 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
|
|||
// Host function: returns the max batch size for the current arch+type at runtime.
|
||||
// Returns the maximum MUL_MAT_ID batch size for which MMVQ should be used, given
// the quantization type and the compute capability `cc` of the current device.
//
// Every architecture comparison is nested under its vendor check so that a raw
// numeric comparison on `cc` is only ever evaluated for the vendor it was
// written for. Devices that match neither vendor branch, or no listed AMD
// family, fall back to MMVQ_MAX_BATCH_SIZE.
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
    // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
        if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
            return MMVQ_MAX_BATCH_SIZE;
        }
        if (cc >= GGML_CUDA_CC_TURING) {
            return get_mmvq_mmid_max_batch_turing_plus(type);
        }
        return get_mmvq_mmid_max_batch_pascal_older(type);
    }

    // AMD
    if (GGML_CUDA_CC_IS_AMD(cc)) {
        if (GGML_CUDA_CC_IS_RDNA4(cc)) {
            return get_mmvq_mmid_max_batch_rdna4(type);
        }
        if (GGML_CUDA_CC_IS_RDNA3(cc)) {
            return get_mmvq_mmid_max_batch_rdna3(type);
        }
        if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
            return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
        }
        if (GGML_CUDA_CC_IS_CDNA(cc)) {
            return get_mmvq_mmid_max_batch_cdna(type);
        }
        if (GGML_CUDA_CC_IS_GCN(cc)) {
            return get_mmvq_mmid_max_batch_gcn(type);
        }
    }

    return MMVQ_MAX_BATCH_SIZE;
}
|
||||
|
|
|
|||
|
|
@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
|
|||
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
|
||||
|
|
|
|||
|
|
@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
|||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
|
|
|||
|
|
@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
|||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
|
|
|
|||
|
|
@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
|
|||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
|
||||
|
|
|
|||
|
|
@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
|||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
|
|
|
|||
|
|
@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
|
|||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
|
||||
|
|
|
|||
|
|
@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
|||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
|
|
|||
|
|
@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
|
|||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(512, 512);
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
|
||||
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ TYPES_MMQ = [
|
|||
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
||||
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
|
||||
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
|
||||
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
|
||||
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4"
|
||||
]
|
||||
|
||||
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
|
@ -83,6 +83,8 @@ for ncols in [8, 16, 32, 64]:
|
|||
continue
|
||||
if head_size_kq == 72:
|
||||
continue
|
||||
if head_size_kq == 512 and ncols2 not in (4, 8):
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 in (16, 32):
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmq.cuh"
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_NVFP4);
|
||||
|
|
@ -16,8 +16,10 @@
|
|||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b)
|
||||
#endif
|
||||
|
||||
// Compute division by a scalar in f32. Requires first expanding fp32 to fp16 and then converting the result back to fp32.
|
||||
|
|
@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX
|
|||
return res;
|
||||
}
|
||||
|
||||
#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
const uint32_t nloe = n % VLEN_FP16; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
|
||||
} \
|
||||
// Variant for <v79: Use pre-computed f16 reciprocal constant
|
||||
static inline HVX_Vector hvx_div_mul_f16_const_using_f16(HVX_Vector vec1_hf, HVX_Vector const_inv_hf) {
|
||||
// Multiply by pre-computed f16 reciprocal constant
|
||||
return HVX_OP_MUL_F16(vec1_hf, const_inv_hf);
|
||||
}
|
||||
|
||||
#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
const uint32_t nloe = n % VLEN_FP16; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector res; \
|
||||
if (__HVX_ARCH__ < 79) { \
|
||||
res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
|
||||
} else { \
|
||||
res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
} \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector res; \
|
||||
if (__HVX_ARCH__ < 79) { \
|
||||
res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
|
||||
} else { \
|
||||
res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
} \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
|
||||
assert((uintptr_t) src % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
|
||||
hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
|
|
@ -128,13 +151,25 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
|
|||
return recip;
|
||||
}
|
||||
|
||||
// Hybrid approach: f16 reciprocal for <v79, f32 precision for >=v79
|
||||
static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
|
||||
#if __HVX_ARCH__ < 79
|
||||
// For older architectures, use f16 reciprocal to avoid NaN/-inf issues
|
||||
HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask);
|
||||
return HVX_OP_MUL_F16(vec1, vec2_inv);
|
||||
#else
|
||||
return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src0_type * restrict vsrc0 = (src0_type *) src0; \
|
||||
src1_type * restrict vsrc1 = (src1_type *) src1; \
|
||||
\
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \
|
||||
const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
|
|
@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
|
|||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
|
||||
HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
|
||||
f32_nan_inf_mask, f16_nan_inf_mask, \
|
||||
hf_one); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
|
||||
HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
|
||||
f32_nan_inf_mask, f16_nan_inf_mask, \
|
||||
hf_one); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
|
@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32)
|
|||
HVX_DIV_DISPATCHER(hvx_div_f16)
|
||||
|
||||
#undef HVX_OP_MUL_F32
|
||||
#undef HVX_OP_MUL_F16
|
||||
|
||||
#endif // HVX_DIV_H
|
||||
|
|
|
|||
|
|
@ -67,34 +67,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|||
uint8_t * restrict pad,
|
||||
const int num_elems,
|
||||
float epsilon) {
|
||||
(void)pad;
|
||||
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
||||
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
||||
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
||||
|
||||
// Compute sum of squares for full vectors
|
||||
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
||||
|
||||
int step_of_1 = num_elems >> 5;
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
|
||||
// Handle tail elements using vectorized ops with masking
|
||||
if (nloe > 0) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
// Reduce HVX sum
|
||||
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
||||
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
|
||||
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
|
||||
|
||||
// Scale full vectors
|
||||
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
||||
}
|
||||
|
||||
// Handle tail elements using vectorized ops with masking
|
||||
if (nloe > 0) {
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
||||
HVX_Vector result = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
// Store with masking to avoid overwriting memory beyond the tensor
|
||||
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include "ggml-impl.h"
|
||||
#include "ggml-sycl.h"
|
||||
#include "presets.hpp"
|
||||
#include "type.hpp"
|
||||
#include "sycl_hw.hpp"
|
||||
|
||||
namespace syclexp = sycl::ext::oneapi::experimental;
|
||||
|
|
@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) {
|
|||
return val;
|
||||
}
|
||||
|
||||
// Converts an unsigned E4M3 scale byte to fp32 and halves the result.
//
// The two all-ones patterns 0x7F and 0xFF are mapped to code 0 first, so they
// contribute a zero scale instead of a NaN.
//
// The byte is decoded directly with e4m3_to_float (the same routine that
// __nv_fp8_e4m3::operator float uses) rather than by viewing a uint32_t through
// a pointer to the one-byte struct: that cast violated strict aliasing and only
// picked the intended byte on little-endian targets.
static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) {
    const uint8_t code = (x == 0x7F || x == 0xFF) ? static_cast<uint8_t>(0) : x;
    return e4m3_to_float(code) / 2;
}
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
|
|
|||
|
|
@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
|
|||
});
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(k % QK_NVFP4 == 0);
|
||||
const int nb = k / QK_NVFP4;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_nvfp4(vx, y, k);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
|
|
@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
|||
return dequantize_row_iq4_nl_sycl;
|
||||
case GGML_TYPE_MXFP4:
|
||||
return dequantize_row_mxfp4_sycl;
|
||||
case GGML_TYPE_NVFP4:
|
||||
return dequantize_row_nvfp4_sycl;
|
||||
case GGML_TYPE_F32:
|
||||
return convert_unary_sycl<float>;
|
||||
#ifdef GGML_SYCL_HAS_BF16
|
||||
|
|
@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
|
|||
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
|
||||
#endif
|
||||
default:
|
||||
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
|
@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
|||
return dequantize_row_iq4_nl_sycl;
|
||||
case GGML_TYPE_MXFP4:
|
||||
return dequantize_row_mxfp4_sycl;
|
||||
case GGML_TYPE_NVFP4:
|
||||
return dequantize_row_nvfp4_sycl;
|
||||
case GGML_TYPE_F16:
|
||||
return convert_unary_sycl<sycl::half>;
|
||||
#ifdef GGML_SYCL_HAS_BF16
|
||||
|
|
@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
|
|||
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
|
||||
#endif
|
||||
default:
|
||||
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename dst_t>
|
||||
static void dequantize_block_nvfp4(
|
||||
const void * __restrict__ vx,
|
||||
dst_t * __restrict__ yy,
|
||||
const int64_t ne) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t i = item_ct1.get_group(2);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
const int64_t base = i * QK_NVFP4;
|
||||
if (base >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const block_nvfp4 * x = (const block_nvfp4 *) vx;
|
||||
const block_nvfp4 & xb = x[i];
|
||||
|
||||
const int sub = tid / (QK_NVFP4_SUB / 2);
|
||||
const int j = tid % (QK_NVFP4_SUB / 2);
|
||||
|
||||
const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]);
|
||||
const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
|
||||
|
||||
const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
|
||||
const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
|
||||
|
||||
yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
|
||||
yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
|
||||
}
|
||||
|
||||
|
||||
#endif // GGML_SYCL_DEQUANTIZE_HPP
|
||||
|
|
|
|||
|
|
@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
|
|||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_NVFP4 == 0);
|
||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
|
||||
{
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
float *dst, const int ncols,
|
||||
|
|
@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
|
|||
case GGML_TYPE_MXFP4:
|
||||
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
|
||||
}
|
||||
}
|
||||
GGML_UNUSED(src1);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,112 @@
|
|||
#pragma once
|
||||
|
||||
#include <sycl/sycl.hpp>

#include <cstdint>
#include <cstring>
#include <limits>
|
||||
|
||||
// Encodes an fp32 value as an 8-bit E4M3 code (1 sign, 4 exponent, 3 mantissa
// bits, exponent bias 7, no infinities, S.1111.111 is the only NaN pattern).
//
// Rounding is round-to-nearest, ties-to-even, applied to the full 24-bit
// significand so that normal results, subnormal results and a carry into the
// next exponent are all handled by the same arithmetic.
//
// Special cases:
//   - NaN input                          -> 0x7F (sign dropped)
//   - +-0 and fp32 subnormals            -> signed zero
//   - magnitude that rounds above 448,
//     including +-infinity               -> sign | 0x7F (NaN)
//
// NOTE(review): CUDA's __nv_fp8_e4m3 constructors are believed to saturate to
// +-448 instead of producing NaN on overflow - confirm which policy is wanted
// here. The overflow-to-NaN policy of the previous implementation is kept.
inline uint8_t float_to_e4m3(float f)
{
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));

    const uint32_t sign_bit = (bits >> 31) << 7;  // 0x00 or 0x80
    const uint32_t abs_bits = bits & 0x7FFFFFFFu;

    if (abs_bits > 0x7F800000u) {
        return 0x7F;  // Canonical NaN (positive)
    }

    const uint32_t exp_f32 = abs_bits >> 23;
    if (exp_f32 == 0) {
        // Zero, or an fp32 subnormal that is far below the smallest E4M3 step.
        return static_cast<uint8_t>(sign_bit);
    }

    const int e = static_cast<int>(exp_f32) - 127;  // true exponent (IEEE bias 127)
    if (e > 8) {
        // Beyond the largest E4M3 binade (2^8); also catches infinity.
        return static_cast<uint8_t>(sign_bit | 0x7Fu);
    }

    // Normal results drop 20 significand bits. Results below 2^-6 are
    // subnormal: they share the scale of exponent field 1 and drop one more bit
    // for every binade they lie below it.
    int biased = e + 7;
    int shift  = 20;
    if (biased < 1) {
        shift += 1 - biased;
        biased = 1;
    }
    if (shift > 24) {
        // Less than half of the smallest subnormal: rounds to zero.
        return static_cast<uint8_t>(sign_bit);
    }

    const uint32_t sig  = (abs_bits & 0x7FFFFFu) | 0x800000u;  // with implicit leading 1
    uint32_t       q    = sig >> shift;
    const uint32_t rem  = sig & ((1u << shift) - 1u);
    const uint32_t half = 1u << (shift - 1);
    if (rem > half || (rem == half && (q & 1u) != 0)) {
        ++q;  // nearest, ties to even
    }

    // For normals q is in [8, 16] and already contains the implicit 1, hence
    // the (biased - 1). q == 16 carries into the next exponent by itself, and a
    // subnormal that rounds up to q == 8 becomes the smallest normal.
    const uint32_t code = (static_cast<uint32_t>(biased - 1) << 3) + q;
    if (code >= 0x7Fu) {
        return static_cast<uint8_t>(sign_bit | 0x7Fu);
    }

    return static_cast<uint8_t>(sign_bit | code);
}
|
||||
|
||||
// Decodes an 8-bit E4M3 code (1 sign, 4 exponent, 3 mantissa bits, bias 7) to
// fp32.
//
// E4M3 has no infinities and only the all-ones pattern S.1111.111 (0x7F / 0xFF)
// is NaN; the remaining codes with exponent field 15 are ordinary values in the
// 256..448 range and are decoded as such.
//
// The magnitude is built as an integer significand times an exact power of two,
// so every representable value converts exactly and no math library call is
// needed.
inline float e4m3_to_float(uint8_t x)
{
    const uint32_t sign = (x >> 7) & 0x1u;
    const uint32_t exp  = (x >> 3) & 0xFu;
    const uint32_t mant = x & 0x7u;

    // NaN (NVIDIA uses 0x7F / 0xFF as NaN)
    if (exp == 0xFu && mant == 0x7u) {
        return std::numeric_limits<float>::quiet_NaN();
    }

    // value = significand * 2^pow2
    //   subnormal: mant/8 * 2^-6            = mant       * 2^-9
    //   normal:    (1 + mant/8) * 2^(exp-7) = (8 + mant) * 2^(exp-10)
    uint32_t significand;
    int      pow2;
    if (exp == 0) {
        significand = mant;
        pow2        = -9;
    } else {
        significand = 8u + mant;
        pow2        = static_cast<int>(exp) - 10;
    }

    float val = static_cast<float>(significand);
    if (pow2 >= 0) {
        val *= static_cast<float>(1u << pow2);
    } else {
        val /= static_cast<float>(1u << -pow2);
    }

    return sign ? -val : val;
}
|
||||
|
||||
// The actual type definition
|
||||
struct __nv_fp8_e4m3 {
|
||||
uint8_t raw;
|
||||
|
||||
__nv_fp8_e4m3() = default;
|
||||
|
||||
explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {}
|
||||
explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {}
|
||||
|
||||
operator float() const { return e4m3_to_float(raw); }
|
||||
operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); }
|
||||
|
||||
// Allow direct access for vector loads/stores
|
||||
operator uint8_t&() { return raw; }
|
||||
operator uint8_t() const { return raw; }
|
||||
};
|
||||
|
||||
using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>;
|
||||
using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>;
|
||||
|
||||
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "ggml.h"
|
||||
#include "type.hpp"
|
||||
#include "quants.hpp"
|
||||
|
||||
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
|
||||
|
|
@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
|
|||
return x32;
|
||||
}
|
||||
|
||||
// Reads the i32-th 32-bit integer from x using two 16-bit loads, for data that
// is only guaranteed to be 2-byte aligned. The first half-word supplies the low
// 16 bits, the second the high 16 bits.
static __dpct_inline__ int get_int_b2(const void * x, const int & i32) {
    const uint16_t * halves = static_cast<const uint16_t *>(x) + 2*i32;

    int packed = halves[0] << 0;
    packed |= halves[1] << 16;

    return packed;
}
|
||||
|
||||
// Reads the i32-th 32-bit integer from x with a single load; x is assumed to be
// at least 4-byte aligned.
static __dpct_inline__ int get_int_b4(const void * x, const int & i32) {
    const int * words = static_cast<const int *>(x);
    return words[i32];
}
|
||||
|
||||
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
|
||||
const uint16_t* x16 =
|
||||
|
|
@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
|
|||
return d * sumi;
|
||||
}
|
||||
|
||||
#define VDR_NVFP4_Q8_1_MMVQ 4
|
||||
#define VDR_NVFP4_Q8_1_MMQ 8
|
||||
|
||||
// Partial dot product between one NVFP4 block and the matching Q8_1 activations.
//
// Starting at 32-bit word `iqs` of the packed quants, the loop handles
// VDR_NVFP4_Q8_1_MMVQ / 2 sub-blocks of two words each. The 4-bit codes are
// expanded through kvalues_mxfp4, multiplied against four Q8_1 words with dp4a,
// and the integer sum is scaled by the sub-block's UE4M3 scale times the first
// component of the Q8_1 block's `ds`.
static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq,
                                                const block_q8_1 * __restrict__ bq8_1,
                                                const int32_t & iqs) {
    const block_nvfp4 * blk = static_cast<const block_nvfp4 *>(vbq);

    float acc = 0.0f;
#pragma unroll
    for (int step = 0; step < VDR_NVFP4_Q8_1_MMVQ/2; step++) {
        const int32_t word_lo = iqs + 2*step;
        const int32_t word_hi = word_lo + 1;
        const int32_t sub     = word_lo >> 1;  // sub-block index inside the NVFP4 block

        const sycl::int2 v0 = get_int_from_table_16(get_int_b4(blk->qs, word_lo), kvalues_mxfp4);
        const sycl::int2 v1 = get_int_from_table_16(get_int_b4(blk->qs, word_hi), kvalues_mxfp4);

        // Two sub-blocks share one Q8_1 block; odd sub-blocks use its upper four words.
        const block_q8_1 * act = bq8_1 + (sub >> 1);
        const int32_t      off = ((sub & 1) << 2);

        int dot = ggml_sycl_dp4a(v0.x(), get_int_b4(act->qs, off + 0), 0);
        dot     = ggml_sycl_dp4a(v0.y(), get_int_b4(act->qs, off + 2), dot);
        dot     = ggml_sycl_dp4a(v1.x(), get_int_b4(act->qs, off + 1), dot);
        dot     = ggml_sycl_dp4a(v1.y(), get_int_b4(act->qs, off + 3), dot);

        const float scale = ggml_sycl_ue4m3_to_fp32(blk->d[sub]) * (act->ds)[0];
        acc += scale * float(dot);
    }

    return acc;
}
|
||||
|
||||
static __dpct_inline__ float
|
||||
vec_dot_q5_0_q8_1(const void *__restrict__ vbq,
|
||||
|
|
|
|||
|
|
@ -1219,9 +1219,8 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
|
||||
// For fast path we always dequantize from f16 inside the shader
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -1334,9 +1333,8 @@ class ggml_webgpu_shader_lib {
|
|||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
|
||||
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
||||
|
||||
// Use f16 inside the shader for quantized types
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
|
||||
variant += std::string("_") + src0_name;
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
|||
|
||||
#define WEBGPU_NUM_PARAM_BUFS 96u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
|
||||
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
|
||||
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100
|
||||
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
|
||||
// parameter buffer pool
|
||||
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
|
||||
|
|
@ -171,6 +171,7 @@ struct webgpu_buf_pool {
|
|||
// Try growing the pool if no free buffers
|
||||
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
|
||||
cur_pool_size++;
|
||||
lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks
|
||||
wgpu::Buffer dev_buf;
|
||||
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
||||
|
||||
|
|
@ -507,7 +508,7 @@ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
|||
|
||||
bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
|
||||
while (blocking_wait) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0);
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
|
||||
|
|
@ -728,7 +729,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
|||
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
|
||||
std::vector<webgpu_command> commands = { command };
|
||||
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
|
||||
ggml_backend_webgpu_wait(ctx, sub);
|
||||
}
|
||||
|
||||
/** End WebGPU Actions */
|
||||
|
|
@ -2694,17 +2694,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||
// memset the remaining bytes
|
||||
ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
|
||||
total_offset + (size - remaining_size), remaining_size);
|
||||
} else {
|
||||
// wait for WriteBuffer to complete
|
||||
buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::QueueWorkDoneStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
|
||||
std::string(message).c_str());
|
||||
}
|
||||
}),
|
||||
UINT64_MAX);
|
||||
}
|
||||
WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef U32_DEQUANT_HELPERS
|
||||
// Returns the 16 bits located at `byte_offset` inside src0, zero-extended to u32.
// Only bit 1 of the offset is used to pick the half of the containing 32-bit
// word, so the offset is expected to be 2-byte aligned.
fn load_src0_u16_at(byte_offset: u32) -> u32 {
    let half_shift = (byte_offset & 2u) << 3u; // 0 for the low half, 16 for the high half
    return (src0[byte_offset >> 2u] >> half_shift) & 0xFFFFu;
}
|
||||
|
||||
// Returns the 32 bits that start at `byte_offset` inside src0. The offset does
// not have to be word aligned: an unaligned read is stitched together from the
// containing word and the word after it.
fn load_src0_u32_at(byte_offset: u32) -> u32 {
    let first_word = byte_offset >> 2u;
    let bit_shift = (byte_offset & 3u) << 3u;
    let low_word = src0[first_word];
    if (bit_shift != 0u) {
        // The upper bytes of the result come from the following word.
        let high_word = src0[first_word + 1u];
        return (low_word >> bit_shift) | (high_word << (32u - bit_shift));
    }
    // Aligned case: the containing word is the answer (and a shift by 32 is avoided).
    return low_word;
}
|
||||
|
||||
// Reads the 16-bit value at `byte_offset` in src0 and reinterprets it as a
// half-precision float.
fn load_src0_f16_at(byte_offset: u32) -> f16 {
    let raw_bits = load_src0_u16_at(byte_offset);
    // unpack2x16float decodes both halves of the word; only the low half holds data here.
    return f16(unpack2x16float(raw_bits).x);
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_0_T
|
||||
struct q4_0 {
|
||||
d: f16,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix;
|
|||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
#elif defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
#define KV_TYPE u32
|
||||
#else
|
||||
#define KV_TYPE f16
|
||||
#endif
|
||||
|
|
@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix;
|
|||
#define NQ 16
|
||||
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
|
||||
#define F16_PER_BLOCK 9
|
||||
#define BLOCK_SIZE_BYTES 18u
|
||||
#define WEIGHTS_PER_F16 4
|
||||
#elif defined(KV_Q8_0)
|
||||
#define NQ 8
|
||||
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
|
||||
#define F16_PER_BLOCK 17
|
||||
#define BLOCK_SIZE_BYTES 34u
|
||||
#define WEIGHTS_PER_F16 2
|
||||
#endif
|
||||
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
|
||||
|
|
@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
|
|||
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
|
||||
}
|
||||
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
// Returns the 16 bits located at `byte_offset` inside K, zero-extended to u32.
// Only bit 1 of the offset selects the half of the containing 32-bit word, so
// the offset is expected to be 2-byte aligned.
fn load_k_u16_at(byte_offset: u32) -> u32 {
    let half_shift = (byte_offset & 2u) << 3u; // 0 for the low half, 16 for the high half
    return (K[byte_offset >> 2u] >> half_shift) & 0xFFFFu;
}
|
||||
|
||||
// Returns the 32 bits that start at `byte_offset` inside K. Unaligned offsets
// are served by combining the containing word with the word that follows it.
fn load_k_u32_at(byte_offset: u32) -> u32 {
    let first_word = byte_offset >> 2u;
    let bit_shift = (byte_offset & 3u) << 3u;
    let low_word = K[first_word];
    if (bit_shift != 0u) {
        // The upper bytes of the result come from the following word.
        let high_word = K[first_word + 1u];
        return (low_word >> bit_shift) | (high_word << (32u - bit_shift));
    }
    // Aligned case: return the word directly (and avoid a shift by 32).
    return low_word;
}
|
||||
|
||||
// Returns the 16 bits located at `byte_offset` inside V, zero-extended to u32.
// Only bit 1 of the offset selects the half of the containing 32-bit word, so
// the offset is expected to be 2-byte aligned.
fn load_v_u16_at(byte_offset: u32) -> u32 {
    let half_shift = (byte_offset & 2u) << 3u; // 0 for the low half, 16 for the high half
    return (V[byte_offset >> 2u] >> half_shift) & 0xFFFFu;
}
|
||||
|
||||
// Returns the 32 bits that start at `byte_offset` inside V. Unaligned offsets
// are served by combining the containing word with the word that follows it.
fn load_v_u32_at(byte_offset: u32) -> u32 {
    let first_word = byte_offset >> 2u;
    let bit_shift = (byte_offset & 3u) << 3u;
    let low_word = V[first_word];
    if (bit_shift != 0u) {
        // The upper bytes of the result come from the following word.
        let high_word = V[first_word + 1u];
        return (low_word >> bit_shift) | (high_word << (32u - bit_shift));
    }
    // Aligned case: return the word directly (and avoid a shift by 32).
    return low_word;
}
|
||||
|
||||
// Reinterprets the low 16 bits of `bits` as a half-precision float.
fn f16_from_u16(bits: u32) -> f16 {
    // unpack2x16float decodes both halves of the word; component x is the low half.
    return f16(unpack2x16float(bits).x);
}
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
offset_q: u32,
|
||||
offset_k: u32,
|
||||
|
|
@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx]; // scale
|
||||
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_k_u16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_k_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
|
|
@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx]; // scale
|
||||
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_k_u16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_k_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
|
|
@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx]; // scale
|
||||
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_v_u16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_v_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
|
|
@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx]; // scale
|
||||
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_v_u16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_v_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
|
|
|
|||
|
|
@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
|
|
@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 1u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
|
|
@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 20u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
|
|
@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f16(q_byte & 0xF) * d + m;
|
||||
|
|
@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
#ifdef INIT_SRC0_SHMEM_Q5_0
|
||||
// 32 weights per block at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 22u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
|
||||
|
|
@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let qh0 = src0[scale_idx + 1u];
|
||||
let qh1 = src0[scale_idx + 2u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
|
@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
#ifdef INIT_SRC0_SHMEM_Q5_1
|
||||
// 32 weights per block at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 24u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
|
||||
|
|
@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
let qh0 = src0[scale_idx + 2u];
|
||||
let qh1 = src0[scale_idx + 3u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
|
||||
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
|
@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 34u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
|
||||
|
||||
|
|
@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_0 = src0[scale_idx + 1u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
|
|
@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 36u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
|
||||
|
||||
|
|
@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
|
|
@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q2_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 42u;
|
||||
const BLOCK_SIZE_BYTES = 84u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
// Use standard thread layout instead of lane/row_group
|
||||
|
|
@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = src0[scale_idx + 40u];
|
||||
let dmin = src0[scale_idx + 41u];
|
||||
let d = load_src0_f16_at(block_byte_base + 80u);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 82u);
|
||||
|
||||
// Decode the element at position k_in_block
|
||||
let block_of_32 = k_in_block / 32u;
|
||||
|
|
@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
let is = k_in_block / 16u;
|
||||
|
||||
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
|
||||
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
|
||||
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
|
||||
let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
|
||||
let sc = get_byte(sc_packed, is % 4u);
|
||||
|
||||
let dl = d * f16(sc & 0xFu);
|
||||
let ml = dmin * f16(sc >> 4u);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
|
|
@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q3_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 55u;
|
||||
const BLOCK_SIZE_BYTES = 110u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
|
|
@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = src0[scale_idx + 54u];
|
||||
let d = load_src0_f16_at(block_byte_base + 108u);
|
||||
|
||||
// Load and unpack scales
|
||||
let kmask1: u32 = 0x03030303u;
|
||||
|
|
@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0u; i < 4u; i++) {
|
||||
let scale_0 = src0[scale_idx + 48u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
|
||||
}
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
|
|
@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
// Load hmask and qs arrays
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0u; i < 8u; i++) {
|
||||
let hmask_0 = src0[scale_idx + (2u*i)];
|
||||
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
|
||||
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
|
||||
hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
|
||||
}
|
||||
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16u; i++) {
|
||||
let qs_0 = src0[scale_idx + 16u + (2u*i)];
|
||||
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
|
||||
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
|
||||
qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
|
||||
}
|
||||
|
||||
let half = k_in_block / 128u; // 0 or 1
|
||||
|
|
@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 72u;
|
||||
const BLOCK_SIZE_BYTES = 144u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
|
|
@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let dmin = src0[scale_idx + 1u];
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
let scale_0 = src0[scale_idx + 2u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
|
|
@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
|
|
@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 88u;
|
||||
const BLOCK_SIZE_BYTES = 176u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
|
|
@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let dmin = src0[scale_idx + 1u];
|
||||
let d = load_src0_f16_at(block_byte_base);
|
||||
let dmin = load_src0_f16_at(block_byte_base + 2u);
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
let scale_0 = src0[scale_idx + 2u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
|
||||
}
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
|
|
@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
|
||||
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
|
||||
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
|
||||
|
||||
let qh_byte = get_byte(qh_packed, l % 4u);
|
||||
|
||||
|
|
@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
#ifdef INIT_SRC0_SHMEM_Q6_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 105u;
|
||||
const BLOCK_SIZE_BYTES = 210u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
|
|
@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let half = k_in_block / 128u;
|
||||
let pos_in_half = k_in_block % 128u;
|
||||
|
|
@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
|
||||
// Load only ql13 word needed
|
||||
let ql13_flat = ql_b_idx + l;
|
||||
let ql13_word = ql13_flat / 4u;
|
||||
let ql13 = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 2u * ql13_word],
|
||||
src0[scale_idx + 2u * ql13_word + 1u]
|
||||
));
|
||||
let ql13_b = get_byte(ql13, ql13_flat % 4u);
|
||||
let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
|
||||
let ql13_b = get_byte(ql13, 0u);
|
||||
|
||||
// Load only ql24 word needed
|
||||
let ql24_flat = ql_b_idx + l + 32u;
|
||||
let ql24_word = ql24_flat / 4u;
|
||||
let ql24 = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 2u * ql24_word],
|
||||
src0[scale_idx + 2u * ql24_word + 1u]
|
||||
));
|
||||
let ql24_b = get_byte(ql24, ql24_flat % 4u);
|
||||
let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
|
||||
let ql24_b = get_byte(ql24, 0u);
|
||||
|
||||
// Load only qh word needed
|
||||
let qh_flat = qh_b_idx + l;
|
||||
let qh_word = qh_flat / 4u;
|
||||
let qh = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 64u + 2u * qh_word],
|
||||
src0[scale_idx + 64u + 2u * qh_word + 1u]
|
||||
));
|
||||
let qh_b = get_byte(qh, qh_flat % 4u);
|
||||
let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
|
||||
let qh_b = get_byte(qh, 0u);
|
||||
|
||||
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
|
||||
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
|
||||
|
|
@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
// Load only the scale word needed
|
||||
let is = l / 16u;
|
||||
let sc_idx = sc_b_idx + is + quarter * 2u;
|
||||
let sc_word = sc_idx / 4u;
|
||||
let sc = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 96u + 2u * sc_word],
|
||||
src0[scale_idx + 96u + 2u * sc_word + 1u]
|
||||
));
|
||||
let sc_val = get_byte_i32(sc, sc_idx % 4u);
|
||||
let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
|
||||
let sc_val = get_byte_i32(sc, 0u);
|
||||
|
||||
let d = src0[scale_idx + 104u];
|
||||
let d = load_src0_f16_at(block_byte_base + 208u);
|
||||
|
||||
var q_val: f16;
|
||||
if (quarter == 0u) {
|
||||
|
|
|
|||
|
|
@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
#ifdef MUL_ACC_Q4_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
|
|
@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 1 + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
|
|
@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
#ifdef MUL_ACC_Q4_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const BLOCK_SIZE_BYTES = 20u;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 10u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
|
|
@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = f32(src0[scale_idx + 1u]);
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = f32(load_src0_f16_at(block_byte_base + 2u));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
|
|
@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
#ifdef MUL_ACC_Q5_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const BLOCK_SIZE_BYTES = 22u;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 11u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
|
|
@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let qh0 = src0[scale_idx + 1u];
|
||||
let qh1 = src0[scale_idx + 2u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
|
@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
#ifdef MUL_ACC_Q5_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const BLOCK_SIZE_BYTES = 24u;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 12u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
|
|
@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = src0[scale_idx + 1u];
|
||||
let qh0 = src0[scale_idx + 2u];
|
||||
let qh1 = src0[scale_idx + 3u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
|
@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
#ifdef MUL_ACC_Q8_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const BLOCK_SIZE_BYTES = 34u;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 17u;
|
||||
const WEIGHTS_PER_F16 = 2u;
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
|
|
@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 1 + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
|
|
@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
#ifdef MUL_ACC_Q8_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const BLOCK_SIZE_BYTES = 36u;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 18u;
|
||||
const WEIGHTS_PER_F16 = 2u;
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
|
|
@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = src0[scale_idx + 1u];
|
||||
let d = f32(load_src0_f16_at(block_byte_base));
|
||||
let m = load_src0_f16_at(block_byte_base + 2u);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
|
||||
let q_packed = load_src0_u32_at(q_byte_offset);
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + f32(m);
|
||||
|
|
@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
#ifdef MUL_ACC_Q6_K
|
||||
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 105u;
|
||||
|
||||
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
|
||||
let aligned = byte_offset & ~3u;
|
||||
let idx = bbase + aligned / 2u;
|
||||
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
|
||||
}
|
||||
const BLOCK_SIZE_BYTES = 210u;
|
||||
|
||||
fn byte_of(v: u32, b: u32) -> u32 {
|
||||
return (v >> (b * 8u)) & 0xFFu;
|
||||
|
|
@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
var local_sum = 0.0;
|
||||
|
||||
for (var i = ix; i < nb; i += 2u) {
|
||||
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
|
||||
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d_raw = load_u32_at(bbase, 208u);
|
||||
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
|
||||
let d = f32(load_src0_f16_at(bbase + 208u));
|
||||
|
||||
let ql1_u32 = load_u32_at(bbase, q_offset_l);
|
||||
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
|
||||
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
|
||||
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
|
||||
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
|
||||
let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
|
||||
let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
|
||||
let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
|
||||
let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
|
||||
let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
|
||||
|
||||
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
|
||||
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
|
||||
|
|
|
|||
|
|
@ -380,22 +380,33 @@ extern "C" {
|
|||
size_t n_samplers;
|
||||
};
|
||||
|
||||
struct llama_model_tensor_override {
|
||||
const char * pattern;
|
||||
enum ggml_type type;
|
||||
};
|
||||
|
||||
// One importance-matrix entry: a named, read-only buffer of floats.
// Referenced by llama_model_quantize_params::imatrix.
struct llama_model_imatrix_data {
    // name identifying this entry
    // NOTE(review): presumably the name of the tensor the data belongs to - confirm against the quantizer
    const char * name;

    // pointer to the float values (not modified through this struct)
    const float * data;

    // extent of `data`
    // NOTE(review): presumably a count of floats rather than bytes - confirm against the producer
    size_t size;
};
|
||||
|
||||
// model quantization parameters
|
||||
typedef struct llama_model_quantize_params {
|
||||
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||
enum ggml_type output_tensor_type; // output tensor type
|
||||
enum ggml_type token_embedding_type; // token embeddings tensor type
|
||||
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
bool quantize_output_tensor; // quantize output.weight
|
||||
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
bool pure; // quantize all tensors to the default type
|
||||
bool keep_split; // quantize to the same number of shards
|
||||
bool dry_run; // calculate and show the final quantization size without performing quantization
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
void * kv_overrides; // pointer to vector containing overrides
|
||||
void * tensor_types; // pointer to vector containing tensor types
|
||||
void * prune_layers; // pointer to vector containing layer indices to prune
|
||||
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
|
||||
enum llama_ftype ftype; // quantize to this llama_ftype
|
||||
enum ggml_type output_tensor_type; // output tensor type
|
||||
enum ggml_type token_embedding_type; // token embeddings tensor type
|
||||
bool allow_requantize; // allow quantizing non-f32/f16 tensors
|
||||
bool quantize_output_tensor; // quantize output.weight
|
||||
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
|
||||
bool pure; // quantize all tensors to the default type
|
||||
bool keep_split; // quantize to the same number of shards
|
||||
bool dry_run; // calculate and show the final quantization size without performing quantization
|
||||
const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data
|
||||
const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides
|
||||
const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides
|
||||
const int32_t * prune_layers; // pointer to layer indices to prune
|
||||
} llama_model_quantize_params;
|
||||
|
||||
typedef struct llama_logit_bias {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
{{- bos_token -}}
|
||||
{%- set keep_past_thinking = keep_past_thinking | default(false) -%}
|
||||
{%- set ns = namespace(system_prompt="") -%}
|
||||
{%- if messages[0]["role"] == "system" -%}
|
||||
{%- set ns.system_prompt = messages[0]["content"] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{%- if tools -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: [" -%}
|
||||
{%- for tool in tools -%}
|
||||
{%- if tool is not string -%}
|
||||
{%- set tool = tool | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + tool -%}
|
||||
{%- if not loop.last -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- set ns.system_prompt = ns.system_prompt + "]" -%}
|
||||
{%- endif -%}
|
||||
{%- if ns.system_prompt -%}
|
||||
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
|
||||
{%- endif -%}
|
||||
{%- set ns.last_assistant_index = -1 -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message["role"] == "assistant" -%}
|
||||
{%- set ns.last_assistant_index = loop.index0 -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message["role"] + "\n" -}}
|
||||
{%- set content = message["content"] -%}
|
||||
{%- if content is not string -%}
|
||||
{%- set content = content | tojson -%}
|
||||
{%- endif -%}
|
||||
{%- if message["role"] == "assistant" and not keep_past_thinking and loop.index0 != ns.last_assistant_index -%}
|
||||
{%- if "</think>" in content -%}
|
||||
{%- set content = content.split("</think>")[-1] | trim -%}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{{- content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
|
|
@ -139,7 +139,11 @@ def main():
|
|||
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL18flash_attn_ext_vecILi128ELi2EL9ggml_type2ELS0_2ELb0EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS6_IjLj3EEiiiiiiiiiiiliiliiiiil',
|
||||
'_ZL9mul_mat_qIL9ggml_type10ELi16ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
|
||||
'_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type40ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type40ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type40ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
|
||||
'_ZL9mul_mat_qIL9ggml_type40ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
|
||||
}
|
||||
|
||||
functions = parse_log_file(log_file)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1 +1 @@
|
|||
a04eea0761a85d18f3f504d6ab970c5c9dce705f
|
||||
50634c28837c24ac68b380b5750b41e701c87d73
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
// dedup helpers
|
||||
|
||||
static ggml_tensor * build_kq_mask(
|
||||
static ggml_tensor * build_attn_inp_kq_mask(
|
||||
ggml_context * ctx,
|
||||
const llama_kv_cache_context * mctx,
|
||||
const llama_ubatch & ubatch,
|
||||
|
|
@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask(
|
|||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(res);
|
||||
ggml_set_name(res, "attn_inp_kq_mask");
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static bool can_reuse_kq_mask(
|
||||
|
|
@ -52,6 +56,21 @@ static bool can_reuse_kq_mask(
|
|||
|
||||
// impl
|
||||
|
||||
// Multiply `cur` by `rot` while preserving the shape of `cur`.
//
// `cur` is first viewed as a 2D tensor whose rows have length n = rot->ne[0]
// (ggml_nelements(cur)/n rows), then passed to ggml_mul_mat together with `rot`,
// and the product is reshaped back to the original 4D extents of `cur`.
// NOTE(review): assumes the element count of `cur` is a multiple of n - confirm at the call sites
static ggml_tensor * ggml_mul_mat_aux(
        ggml_context * ctx,
        ggml_tensor  * cur,
        ggml_tensor  * rot) {
    const auto n_rot  = rot->ne[0];
    const auto n_rows = ggml_nelements(cur)/n_rot;

    ggml_tensor * flat    = ggml_reshape_2d(ctx, cur, n_rot, n_rows);
    ggml_tensor * rotated = ggml_mul_mat   (ctx, rot, flat);

    return ggml_reshape_4d(ctx, rotated, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
}
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
|
@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
|
|||
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->set_input_k_rot(self_k_rot);
|
||||
}
|
||||
|
||||
if (self_v_rot) {
|
||||
mctx->set_input_v_rot(self_v_rot);
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
||||
|
|
@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
|||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->get_base()->set_input_k_rot(self_k_rot);
|
||||
}
|
||||
|
||||
if (self_v_rot) {
|
||||
mctx->get_base()->set_input_v_rot(self_v_rot);
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
|
|
@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|||
|
||||
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
if (inp_attn->self_k_rot) {
|
||||
mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
|
||||
}
|
||||
|
||||
if (inp_attn->self_v_rot) {
|
||||
mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
|
|
@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
|||
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (inp_attn->self_k_rot) {
|
||||
attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
|
||||
}
|
||||
|
||||
if (inp_attn->self_v_rot) {
|
||||
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
|
|
@ -2002,13 +2053,13 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
|||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
|
|
@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
int il) const {
|
||||
GGML_ASSERT(v_mla == nullptr);
|
||||
|
||||
if (inp->self_k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
|
||||
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
|
||||
}
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
|
||||
}
|
||||
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
// expand k later to enable rope fusion which directly writes into k-v cache
|
||||
|
|
@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
|
||||
}
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
|
||||
|
|
@ -2090,9 +2154,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
|||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
|
|
@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
if (inp->self_k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
|
||||
if (k_cur) {
|
||||
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
|
||||
}
|
||||
}
|
||||
if (inp->self_v_rot) {
|
||||
if (v_cur) {
|
||||
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
|
||||
}
|
||||
}
|
||||
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
|
|
@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
|
||||
}
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
}
|
||||
|
|
@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
||||
|
||||
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
|
||||
}
|
||||
|
||||
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
|
||||
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
|
|
@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
||||
ggml_set_input(inp_attn->self_kq_mask);
|
||||
|
||||
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
||||
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
||||
}
|
||||
|
||||
|
|
@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
||||
ggml_set_input(inp_attn->self_kq_mask_swa);
|
||||
|
||||
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
||||
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -308,6 +308,10 @@ public:
|
|||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
// note: assumes v_rot^2 == I
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
ggml_tensor * self_v_rot = nullptr;
|
||||
|
||||
// note: these have to be copies because in order to be able to reuse a graph, its inputs
|
||||
// need to carry these parameters with them. otherwise, they can point to freed
|
||||
// llm_graph_params from a previous batch, causing stack-use-after-return
|
||||
|
|
@ -384,6 +388,10 @@ public:
|
|||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
// note: using same rotation matrices for both base and swa cache
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
ggml_tensor * self_v_rot = nullptr;
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,65 @@
|
|||
#include <map>
|
||||
#include <stdexcept>
|
||||
|
||||
// Returns true iff n is a positive power of two (1, 2, 4, 8, ...).
//
// The n > 0 guard is required: the bare bit test `(n & (n - 1)) == 0` also
// accepts 0, and evaluating `n - 1` for INT_MIN would be signed overflow.
static bool ggml_is_power_of_2(int n) {
    return n > 0 && (n & (n - 1)) == 0;
}
|
||||
|
||||
// orthonormal Walsh-Hadamard rotation matrix
|
||||
// note: res^2 == I
|
||||
static void ggml_gen_hadamard(ggml_tensor * tensor) {
|
||||
assert(tensor->type == GGML_TYPE_F32);
|
||||
|
||||
const int n = tensor->ne[0];
|
||||
|
||||
assert(ggml_is_power_of_2(n));
|
||||
assert(tensor->ne[1] == n);
|
||||
assert(tensor->ne[2] == 1);
|
||||
assert(tensor->ne[3] == 1);
|
||||
|
||||
std::vector<float> data_f32;
|
||||
|
||||
float * data = (float *) tensor->data;
|
||||
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
data_f32.resize(n*n);
|
||||
data = data_f32.data();
|
||||
}
|
||||
|
||||
data[0*n + 0] = 1.0 / sqrtf(n);
|
||||
|
||||
for (int s = 1; s < n; s *= 2) {
|
||||
for (int i = 0; i < s; i++) {
|
||||
for (int j = 0; j < s; j++) {
|
||||
const float val = data[i*n + j];
|
||||
|
||||
data[(i + s)*n + (j )] = val;
|
||||
data[(i )*n + (j + s)] = val;
|
||||
data[(i + s)*n + (j + s)] = -val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Applies the square rotation matrix `rot` to `cur`.
//
// `cur` is viewed as a 2D tensor with rows of n = rot->ne[0] elements, multiplied by `rot`
// with ggml_mul_mat, and the product is reshaped back to the original 4D shape of `cur`.
//
// A null `rot` means that no rotation is configured; `cur` is returned unchanged in that
// case (identity rotation).
static ggml_tensor * ggml_mul_mat_aux(
        ggml_context * ctx,
        ggml_tensor  * cur,
        ggml_tensor  * rot) {
    // build_input_k_rot()/build_input_v_rot() return nullptr when the rotation is disabled,
    // and callers such as the quantized K-shift path forward that pointer as-is
    if (rot == nullptr) {
        return cur;
    }

    const auto n = rot->ne[0];

    ggml_tensor * res;

    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
    res = ggml_mul_mat   (ctx, rot, res);
    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);

    return res;
}
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
|
@ -209,6 +268,48 @@ llama_kv_cache::llama_kv_cache(
|
|||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
||||
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
|
||||
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
|
||||
if (attn_rot_disable) {
|
||||
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
|
||||
}
|
||||
|
||||
attn_rot_k =
|
||||
!attn_rot_disable &&
|
||||
ggml_is_quantized(type_k) &&
|
||||
!hparams.is_n_embd_k_gqa_variable() &&
|
||||
hparams.n_embd_head_k() % 64 == 0;
|
||||
|
||||
attn_rot_v =
|
||||
!attn_rot_disable &&
|
||||
ggml_is_quantized(type_v) &&
|
||||
!hparams.is_n_embd_v_gqa_variable() &&
|
||||
hparams.n_embd_head_v() % 64 == 0;
|
||||
|
||||
LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
|
||||
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
|
||||
|
||||
// pre-compute the hadamard matrices and keep them in host memory
|
||||
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
|
||||
if (attn_rot_k || attn_rot_v) {
|
||||
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
|
||||
attn_rot_hadamard[n] = std::vector<float>(n*n);
|
||||
|
||||
ggml_init_params params = {
|
||||
/* .mem_size = */ 1*ggml_tensor_overhead(),
|
||||
/* .mem_buffer = */ nullptr,
|
||||
/* .no_alloc = */ true,
|
||||
};
|
||||
|
||||
ggml_context_ptr ctx { ggml_init(params) };
|
||||
|
||||
ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n);
|
||||
tmp->data = attn_rot_hadamard[n].data();
|
||||
|
||||
ggml_gen_hadamard(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
|
||||
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
|
||||
}
|
||||
|
|
@ -1004,6 +1105,14 @@ bool llama_kv_cache::get_has_shift() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Returns the tensor type of the K cache.
// The type is read from the first layer; presumably all layers use the same one.
ggml_type llama_kv_cache::type_k() const {
    // indexing layers[0] on a cache without layers would be undefined behavior
    GGML_ASSERT(!layers.empty());

    return layers[0].k->type;
}
|
||||
|
||||
// Returns the tensor type of the V cache.
// The type is read from the first layer; presumably all layers use the same one.
ggml_type llama_kv_cache::type_v() const {
    // indexing layers[0] on a cache without layers would be undefined behavior
    GGML_ASSERT(!layers.empty());

    return layers[0].v->type;
}
|
||||
|
||||
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
|
||||
uint32_t result = 0;
|
||||
|
||||
|
|
@ -1189,6 +1298,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
|
|||
return v_idxs;
|
||||
}
|
||||
|
||||
// Creates the graph input tensor that receives the rotation matrix for K.
// Returns nullptr when the K rotation is not enabled for this cache.
ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
    if (!attn_rot_k) {
        return nullptr;
    }

    const auto n_embd_head = hparams.n_embd_head_k();

    // start from 64 and keep doubling for as long as the doubled size still divides the K head size
    // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V)
    //       ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
    int nrot = 64;
    while (n_embd_head % (2*nrot) == 0) {
        nrot *= 2;
    }

    ggml_tensor * rot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);

    ggml_set_input(rot);
    ggml_set_name (rot, "attn_inp_k_rot");

    return rot;
}
|
||||
|
||||
// Creates the graph input tensor that receives the rotation matrix for V.
// Returns nullptr when the V rotation is not enabled for this cache.
ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const {
    if (!attn_rot_v) {
        return nullptr;
    }

    // V always uses the smallest 64 x 64 matrix: using smaller rotation matrices for V seems beneficial
    // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570
    //
    // (the alternative - growing the size while it still divides n_embd_head_v(), as is
    //  done for K - is intentionally not applied here)
    const int nrot = 64;

    ggml_tensor * rot = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);

    ggml_set_input(rot);
    ggml_set_name (rot, "attn_inp_v_rot");

    return rot;
}
|
||||
|
||||
void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
|
||||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
|
||||
|
|
@ -1507,6 +1657,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
|
|||
}
|
||||
}
|
||||
|
||||
// Copies the pre-computed hadamard matrix of size dst->ne[0] into the K rotation input.
// The destination must live in a host buffer and a matrix of that size must be available.
void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const {
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));

    // the lookup key is the side length of the square rotation matrix
    const std::vector<float> & hadamard = attn_rot_hadamard.at(dst->ne[0]);

    memcpy(dst->data, hadamard.data(), ggml_nbytes(dst));
}
|
||||
|
||||
// Copies the pre-computed hadamard matrix of size dst->ne[0] into the V rotation input.
// The destination must live in a host buffer and a matrix of that size must be available.
void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const {
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));

    // the lookup key is the side length of the square rotation matrix
    const std::vector<float> & hadamard = attn_rot_hadamard.at(dst->ne[0]);

    memcpy(dst->data, hadamard.data(), ggml_nbytes(dst));
}
|
||||
|
||||
size_t llama_kv_cache::total_size() const {
|
||||
size_t size = 0;
|
||||
|
||||
|
|
@ -1542,6 +1710,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
|||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * rot,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
|
|
@ -1567,10 +1736,16 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
|||
// dequantize to f32 -> RoPE -> quantize back
|
||||
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
|
||||
// rotate back
|
||||
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
|
||||
|
||||
tmp = ggml_rope_ext(ctx, tmp,
|
||||
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
||||
|
||||
// rotate fwd
|
||||
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
|
||||
|
||||
tmp = ggml_cpy(ctx, tmp, cur);
|
||||
} else {
|
||||
// we rotate only the first n_rot dimensions
|
||||
|
|
@ -1591,6 +1766,9 @@ public:
|
|||
|
||||
ggml_tensor * k_shift; // I32 [kv_size*n_stream]
|
||||
|
||||
// note: assumes k_rot^2 == I
|
||||
ggml_tensor * k_rot = nullptr;
|
||||
|
||||
const llama_kv_cache * kv_self;
|
||||
};
|
||||
|
||||
|
|
@ -1600,6 +1778,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
|||
if (k_shift) {
|
||||
kv_self->set_input_k_shift(k_shift);
|
||||
}
|
||||
|
||||
if (k_rot) {
|
||||
kv_self->set_input_k_rot(k_rot);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
|
||||
|
|
@ -1611,6 +1793,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
||||
ggml_set_input(inp->k_shift);
|
||||
|
||||
inp->k_rot = build_input_k_rot(ctx);
|
||||
|
||||
const auto & cparams = lctx->get_cparams();
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
|
|
@ -1635,7 +1819,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_nope));
|
||||
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
|
@ -2239,6 +2423,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
|
|||
return n_kv;
|
||||
}
|
||||
|
||||
// Returns the K tensor type reported by the underlying cache (forwards to kv->type_k()).
ggml_type llama_kv_cache_context::type_k() const {
|
||||
return kv->type_k();
|
||||
}
|
||||
|
||||
// Returns the V tensor type reported by the underlying cache (forwards to kv->type_v()).
ggml_type llama_kv_cache_context::type_v() const {
|
||||
return kv->type_v();
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
|
@ -2263,6 +2455,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con
|
|||
return kv->build_input_v_idxs(ctx, ubatch);
|
||||
}
|
||||
|
||||
// Builds the K rotation graph input by forwarding to kv->build_input_k_rot().
// The result is whatever the underlying cache returns, including nullptr.
ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const {
|
||||
return kv->build_input_k_rot(ctx);
|
||||
}
|
||||
|
||||
// Builds the V rotation graph input by forwarding to kv->build_input_v_rot().
// The result is whatever the underlying cache returns, including nullptr.
ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const {
|
||||
return kv->build_input_v_rot(ctx);
|
||||
}
|
||||
|
||||
void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||
kv->set_input_k_shift(dst);
|
||||
}
|
||||
|
|
@ -2282,3 +2482,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_pos_bucket(dst, ubatch);
|
||||
}
|
||||
|
||||
// Fills the K rotation input tensor `dst` by forwarding to kv->set_input_k_rot().
void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const {
|
||||
kv->set_input_k_rot(dst);
|
||||
}
|
||||
|
||||
// Fills the V rotation input tensor `dst` by forwarding to kv->set_input_v_rot().
void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const {
|
||||
kv->set_input_v_rot(dst);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -152,6 +152,9 @@ public:
|
|||
|
||||
bool get_has_shift() const;
|
||||
|
||||
ggml_type type_k() const;
|
||||
ggml_type type_v() const;
|
||||
|
||||
//
|
||||
// graph_build API
|
||||
//
|
||||
|
|
@ -191,6 +194,9 @@ public:
|
|||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
|
||||
ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
|
||||
|
|
@ -199,6 +205,9 @@ public:
|
|||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
void set_input_k_rot(ggml_tensor * dst) const;
|
||||
void set_input_v_rot(ggml_tensor * dst) const;
|
||||
|
||||
private:
|
||||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
|
@ -226,6 +235,13 @@ private:
|
|||
// SWA
|
||||
const uint32_t n_swa = 0;
|
||||
|
||||
// env: LLAMA_ATTN_ROT_DISABLE
|
||||
bool attn_rot_k = false;
|
||||
bool attn_rot_v = false;
|
||||
|
||||
// pre-computed hadamard matrices
|
||||
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
|
||||
|
||||
// env: LLAMA_KV_CACHE_DEBUG
|
||||
int debug = 0;
|
||||
|
||||
|
|
@ -262,6 +278,7 @@ private:
|
|||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * rot,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
|
|
@ -328,6 +345,9 @@ public:
|
|||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
ggml_type type_k() const;
|
||||
ggml_type type_v() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
|
@ -347,6 +367,9 @@ public:
|
|||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
|
||||
ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
|
|
@ -354,6 +377,9 @@ public:
|
|||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
void set_input_k_rot(ggml_tensor * dst) const;
|
||||
void set_input_v_rot(ggml_tensor * dst) const;
|
||||
|
||||
private:
|
||||
llama_memory_status status;
|
||||
|
||||
|
|
|
|||
|
|
@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// TODO: non-sequential equal split can be done if using unified KV cache
|
||||
// for simplicity, we always use sequential equal split for now
|
||||
ubatch = balloc.split_equal(n_ubatch, true);
|
||||
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
|
||||
const bool unified = (mem_attn->get_base()->get_n_stream() == 1);
|
||||
ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
|
|
|
|||
|
|
@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
|
|||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// TODO: non-sequential equal split can be done if using unified KV cache
|
||||
// for simplicity, we always use sequential equal split for now
|
||||
ubatch = balloc.split_equal(n_ubatch, true);
|
||||
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
|
||||
const bool unified = (mem_attn->get_n_stream() == 1);
|
||||
ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ static std::string remap_imatrix(const std::string & orig_name, const std::map<i
|
|||
|
||||
for (const auto & p : mapped) {
|
||||
if (p.second == blk) {
|
||||
LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
|
||||
return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
|
||||
}
|
||||
}
|
||||
|
|
@ -188,10 +187,9 @@ struct quantize_state_impl {
|
|||
model(model), params(params)
|
||||
{
|
||||
// compile regex patterns once - they are expensive
|
||||
if (params->tensor_types) {
|
||||
const auto & tensor_types = *static_cast<const std::vector<tensor_type_option> *>(params->tensor_types);
|
||||
for (const auto & [tname, qtype] : tensor_types) {
|
||||
tensor_type_patterns.emplace_back(std::regex(tname), qtype);
|
||||
if (params->tt_overrides) {
|
||||
for (const auto * p = params->tt_overrides; p->pattern != nullptr; p++) {
|
||||
tensor_type_patterns.emplace_back(std::regex(p->pattern), p->type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -857,12 +855,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
constexpr bool use_mmap = false;
|
||||
#endif
|
||||
|
||||
llama_model_kv_override * kv_overrides = nullptr;
|
||||
if (params->kv_overrides) {
|
||||
auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
|
||||
kv_overrides = v->data();
|
||||
}
|
||||
|
||||
const llama_model_kv_override * kv_overrides = params->kv_overrides;
|
||||
std::vector<std::string> splits = {};
|
||||
llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr,
|
||||
fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
|
||||
|
|
@ -879,9 +872,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
if (params->only_copy) {
|
||||
ftype = ml.ftype;
|
||||
}
|
||||
std::unordered_map<std::string, std::vector<float>> i_data;
|
||||
const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
|
||||
if (params->imatrix) {
|
||||
imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
|
||||
for (const llama_model_imatrix_data * p = params->imatrix; p->name != nullptr; p++) {
|
||||
i_data.emplace(p->name, std::vector<float>(p->data, p->data + p->size));
|
||||
}
|
||||
imatrix_data = & i_data;
|
||||
if (imatrix_data) {
|
||||
LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n",
|
||||
__func__, (int)imatrix_data->size());
|
||||
|
|
@ -902,7 +899,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
|
||||
std::vector<int> prune_list = {};
|
||||
if (params->prune_layers) {
|
||||
prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
|
||||
for (const int32_t * p = params->prune_layers; * p != -1; p++) {
|
||||
prune_list.push_back(* p);
|
||||
}
|
||||
}
|
||||
|
||||
// copy the KV pairs from the input file
|
||||
|
|
@ -916,20 +915,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
|
||||
|
||||
if (params->kv_overrides) {
|
||||
const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
|
||||
for (const auto & o : overrides) {
|
||||
if (o.key[0] == 0) break;
|
||||
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
|
||||
gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
|
||||
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
|
||||
for (const llama_model_kv_override * o = params->kv_overrides; o->key[0] != 0; ++o) {
|
||||
if (o->tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
|
||||
gguf_set_val_f32(ctx_out.get(), o->key, o->val_f64);
|
||||
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
|
||||
// Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
|
||||
gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64));
|
||||
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
|
||||
gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
|
||||
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
|
||||
gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
|
||||
gguf_set_val_u32(ctx_out.get(), o->key, (uint32_t)std::abs(o->val_i64));
|
||||
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
|
||||
gguf_set_val_bool(ctx_out.get(), o->key, o->val_bool);
|
||||
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
|
||||
gguf_set_val_str(ctx_out.get(), o->key, o->val_str);
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
|
||||
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o->key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2712,6 +2712,67 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
|||
.run();
|
||||
}
|
||||
|
||||
// LFM2.5 tests - uses plain "List of tools: [...]" and bare [name(args)] without wrapper tokens
|
||||
{
|
||||
auto tst = peg_tester("models/templates/LFM2.5-Instruct.jinja", detailed_debug);
|
||||
|
||||
// Basic content only
|
||||
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
|
||||
|
||||
// Single tool call without reasoning
|
||||
tst.test("[special_function(arg1=1)]")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
// Tool call with string argument
|
||||
tst.test("[get_time(city=\"XYZCITY\")]")
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
|
||||
.run();
|
||||
|
||||
// Tool call with reasoning (enable_thinking=true)
|
||||
tst.test("<think>I'm\nthinking</think>[special_function(arg1=1)]")
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call_thoughts)
|
||||
.run();
|
||||
|
||||
// Multiple tool calls (parallel)
|
||||
tst.test("[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]")
|
||||
.parallel_tool_calls(true)
|
||||
.tools({
|
||||
special_function_tool, special_function_tool_with_optional_param
|
||||
})
|
||||
.expect_tool_calls({
|
||||
{ "special_function", R"({"arg1": 1})", {} },
|
||||
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with content before tool call
|
||||
tst.test("Let me check the time.[get_time(city=\"Paris\")]")
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_reasoning_content_and_multiple_tool_calls(
|
||||
"", "Let me check the time.", { { "get_time", "{\"city\":\"Paris\"}" } }
|
||||
))
|
||||
.run();
|
||||
|
||||
// Partial tool call (streaming)
|
||||
tst.test("[special_function(arg1=")
|
||||
.tools({ special_function_tool })
|
||||
.is_partial(true)
|
||||
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
|
||||
.run();
|
||||
|
||||
// Tool call with empty arguments
|
||||
tst.test("[empty_args()]")
|
||||
.tools({ empty_args_tool })
|
||||
.expect(simple_assist_msg("", "", "empty_args", "{}"))
|
||||
.run();
|
||||
}
|
||||
|
||||
// Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format
|
||||
// Format: <|tools_prefix|>[{"function_name": {...arguments...}}]<|tools_suffix|>
|
||||
{
|
||||
|
|
|
|||
|
|
@ -13,13 +13,10 @@
|
|||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <fstream>
|
||||
#include <cmath>
|
||||
#include <cctype>
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
|
||||
// result of parsing --tensor-type option
|
||||
// (changes to this struct must be reflected in src/llama-quant.cpp)
|
||||
// changes to this struct must also be reflected in src/llama-quant.cpp
|
||||
struct tensor_type_option {
|
||||
std::string name;
|
||||
ggml_type type = GGML_TYPE_COUNT;
|
||||
|
|
@ -491,7 +488,6 @@ static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers
|
|||
|
||||
int main(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
if (argc < 3) {
|
||||
usage(argv[0]);
|
||||
}
|
||||
|
|
@ -584,8 +580,16 @@ int main(int argc, char ** argv) {
|
|||
std::vector<std::string> imatrix_datasets;
|
||||
std::unordered_map<std::string, std::vector<float>> imatrix_data;
|
||||
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data);
|
||||
|
||||
std::vector<llama_model_imatrix_data> i_data;
|
||||
std::vector<llama_model_tensor_override> t_override;
|
||||
if (!imatrix_data.empty()) {
|
||||
params.imatrix = &imatrix_data;
|
||||
i_data.reserve(imatrix_data.size() + 1);
|
||||
for (const auto & kv : imatrix_data) {
|
||||
i_data.push_back({kv.first.c_str(), kv.second.data(), kv.second.size()});
|
||||
}
|
||||
i_data.push_back({nullptr, nullptr, 0}); // array terminator
|
||||
params.imatrix = i_data.data();
|
||||
{
|
||||
llama_model_kv_override kvo;
|
||||
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE);
|
||||
|
|
@ -603,7 +607,6 @@ int main(int argc, char ** argv) {
|
|||
kvo.val_str[127] = '\0';
|
||||
kv_overrides.emplace_back(std::move(kvo));
|
||||
}
|
||||
|
||||
{
|
||||
llama_model_kv_override kvo;
|
||||
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES);
|
||||
|
|
@ -611,7 +614,6 @@ int main(int argc, char ** argv) {
|
|||
kvo.val_i64 = imatrix_data.size();
|
||||
kv_overrides.emplace_back(std::move(kvo));
|
||||
}
|
||||
|
||||
if (m_last_call > 0) {
|
||||
llama_model_kv_override kvo;
|
||||
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS);
|
||||
|
|
@ -623,13 +625,19 @@ int main(int argc, char ** argv) {
|
|||
if (!kv_overrides.empty()) {
|
||||
kv_overrides.emplace_back();
|
||||
kv_overrides.back().key[0] = 0;
|
||||
params.kv_overrides = &kv_overrides;
|
||||
params.kv_overrides = kv_overrides.data();
|
||||
}
|
||||
if (!tensor_type_opts.empty()) {
|
||||
params.tensor_types = &tensor_type_opts;
|
||||
t_override.reserve(tensor_type_opts.size() + 1);
|
||||
for (const auto & tt : tensor_type_opts) {
|
||||
t_override.push_back({tt.name.c_str(), tt.type});
|
||||
}
|
||||
t_override.push_back({nullptr, GGML_TYPE_COUNT}); // array terminator
|
||||
params.tt_overrides = t_override.data();
|
||||
}
|
||||
if (!prune_layers.empty()) {
|
||||
params.prune_layers = &prune_layers;
|
||||
prune_layers.push_back(-1); // array terminator
|
||||
params.prune_layers = prune_layers.data();
|
||||
}
|
||||
|
||||
llama_backend_init();
|
||||
|
|
|
|||
|
|
@ -1196,6 +1196,10 @@ server_http_proxy::server_http_proxy(
|
|||
// disable Accept-Encoding to avoid compressed responses
|
||||
continue;
|
||||
}
|
||||
if (key == "Transfer-Encoding") {
|
||||
// the body is already decoded
|
||||
continue;
|
||||
}
|
||||
if (key == "Host" || key == "host") {
|
||||
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
|
||||
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
|
||||
|
|
|
|||
Loading…
Reference in New Issue