Compare commits

...

17 Commits

Author SHA1 Message Date
Eden 029a6abe30
Merge 2266f48d68 into 1d6d4cf7a5 2026-04-01 17:27:38 +02:00
Jonathan 1d6d4cf7a5
fix: tool call parsing for LFM2 and LFM2.5 models (#21242)
* fix: tool call parsing for LFM2 and LFM2.5 models'

* refactor: add test / break out lfm2 and lfm2.5 parsing logic
2026-04-01 16:22:44 +02:00
Georgi Gerganov 744c0c7310
llama : rotate activations for better quantization (#21038)
* llama : rotate activations for better quantization

* cont : rotate V more + refactor

* cont : rotate caches separately + support non-power-of-2 head sizes

* cont : simplify

* cont : add reference for V rotation

* cont : refactor

* cont : support context shift

* cont : consolidate

* cont : dedup + allow different types for the rotation matrix

* cont : add env variable to disable rotation

* cont : simplify attn rot kv cache logic + rename env

* cont : pre-compute the Hadamard matrices
2026-04-01 16:58:01 +03:00
Xuan-Son Nguyen 0356e33aaf
scripts: add function call test script (#21234)
* scripts: add function call test script

* add reasoning_content

* fix lint
2026-04-01 15:31:58 +02:00
Georgi Gerganov 6422036fcb sync : ggml 2026-04-01 16:03:17 +03:00
Georgi Gerganov 296bc0538b ggml : bump version to 0.9.10 (ggml/1454) 2026-04-01 16:03:17 +03:00
Neo Zhang 6b949d1078
sycl : support nvfp4 type in mul_mat (#21227) 2026-04-01 13:54:15 +03:00
Michael Wand 84f82e846c
ggml-cuda: Add generic NVFP4 MMQ kernel (#21074)
* Introduced NVFP4 generic MMQ kernel

* Added extra FP8 guard, hope to solve ci HIP failure

* Rename tiles and use HIP_FP8_AVAILABLE

* Removed remaning FP8 straggler and added const int

* Const

* Removed DECL_MMQ_CASE artifact

* Removed newline

* Removed space after else

* Changed HIP FP8 NVFP4 conversion gate

* Added new line to bottom of mmq.cu 270

* Removed extra spaces

* Removed single space in front of else on line 814

* Added NVFP4 to generate cu script so HIP can see it, further tightened logic

* Include generated mmq-instance-nvfp4.cu

* Added NVFP4 mmq to HIP Check ignore list

* Update ggml/src/ggml-cuda/mmq.cuh

Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/mmq.cuh

Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 in tile assert

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/mmq.cuh

Added function name ending for end if

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Added function names to closing endif

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-04-01 12:04:58 +02:00
Ettore Di Giacinto e1cb817483
memory: respect unified KV cache in hybrid memory for eval tasks (#21224)
The hybrid memory paths (`llama-memory-hybrid.cpp` and
`llama-memory-hybrid-iswa.cpp`) always used sequential equal split,
ignoring the unified KV cache flag. This caused hellaswag, winogrande,
and multiple-choice evaluations to fail on hybrid models (models with
both attention and recurrent/SSM layers, such as Qwen3.5-35B-A3B) with:

  split_equal: sequential split is not supported when there are
  coupled sequences in the input batch (you may need to use the
  -kvu flag)

PR #19954 fixed this for `llama-kv-cache-iswa.cpp` by automatically
enabling unified KV mode and setting n_parallel >= 4 for multi-choice
eval tasks. However, the hybrid memory paths were not updated.

This commit mirrors the iswa fix: use non-sequential split when KV
cache is unified (n_stream == 1), which is automatically set by
llama-perplexity for hellaswag/winogrande/multiple-choice since #19954.

Tested on Qwen3.5-35B-A3B (hybrid attention+SSM MoE model):
- HellaSwag: 83.0% (400 tasks)
- Winogrande: 74.5% (400 tasks)
- MMLU: 41.2%
- ARC-Challenge: 56.2%
- TruthfulQA: 37.7%
All previously failed with llama_decode() error.
2026-04-01 12:50:17 +03:00
uvos 88d5f8ffc3
CUDA/HIP: Fix kernel slection for mmvq mmid kernel to align host selection with device launch bounds (#21238)
The conditions cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE and cc >= GGML_CUDA_CC_TURING match all non-nvidia devices. This causes us to attempt to launch the kernel for batch sizes with larger configurations than our launch bounds on HIP devices. This pr fixes the conditionals in get_mmvq_mmid_max_batch.

Fixes #21191
2026-04-01 10:21:20 +02:00
Georgi Gerganov d43375ff7f
ggml : fix RWKV ops thread assignment (#21226) 2026-04-01 11:10:25 +03:00
Taimur Ahmad 2b86e5cae6
ggml-cpu: fix fallback for RVV kernels without zvfh (#21157)
* ggml-cpu: refactor sgemm; fix rvv checks

* ggml-cpu: refactor rvv kernels; set zvfbfwma default to off
2026-04-01 11:10:03 +03:00
Anav Prasad 88458164c7
CUDA: Add Flash Attention Support for Head Dimension 512 (#20998)
* flash attention support for head dimension 512 added

* FA D=512 - match 576 configs, limit ncols2, revert vec cap

* fix HIP tile kernel build for D=512

* fix HIP tile kernel occupancy for D=512 on AMD

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* fix tile FA compilation

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-04-01 09:07:24 +02:00
Ed Addario 4951250235
llama : refactor llama_model_quantize_params to expose a pure C interface (#20346)
* Refactor llama_model_quantize_params to expose a pure C interface

* Restore comment and cleanup struct def

* Code review refactoring

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Code review refactoring

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-04-01 08:43:00 +03:00
Reese Levine 82764c341a
ggml webgpu: quantized buffers to u32 + wider browser/device support (#21046)
* Work towards removing bitcast

* Move rest of existing types over

* Add timeout back to wait and remove synchronous set_tensor/memset_tensor

* move to unpackf16 for wider compatibility

* cleanup

* Remove deadlock condition in free_bufs
2026-04-01 08:38:24 +03:00
Abhijit Ramesh 825eb91a66
ggml-webgpu: port all AOT operators to JIT (#20728)
* port cpy pipeline to shader lib with JIT compilation
 * port glu pipeline to shader lib with JIT compilation
 * port rope pipeline to shader lib with JIT compilation
 * port soft_max pipeline to shader lib with JIT compilation
 * removed unused functions from embed_wgsl.py which were used for
old AOT template expansion
2026-03-31 15:38:16 -07:00
許元豪 2266f48d68 cli, server: apply --prio process priority setting
The --prio flag was parsed but never applied in llama-cli and
llama-server. Only llama-completion and llama-bench called
set_process_priority(). Add the missing calls after backend
initialization so the flag takes effect in all tools.
2026-03-11 09:28:43 +08:00
60 changed files with 3427 additions and 1606 deletions

View File

@ -1274,11 +1274,12 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
return data;
}
// LFM2 format:
// - Reasoning: <think>{reasoning}</think> (optional, only if enable_thinking is true)
// - Content: text after reasoning (optional)
// - Tool calls: <|tool_call_start|>[function_name(arg1="value1", arg2="value2")]<|tool_call_end|>
// Tool calls can appear multiple times (parallel tool calls)
// LFM2 format: uses <|tool_list_start|>[...]<|tool_list_end|> in system prompt
// and <|tool_call_start|>[name(arg="val")]<|tool_call_end|> for tool calls.
// - Reasoning: <think>{reasoning}</think> (optional)
// - Content: text before a tool call (optional)
// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")]
// Tool calls can appear multiple times (parallel tool calls supported)
static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
common_chat_params data;
@ -1319,9 +1320,9 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return generation_prompt + reasoning + p.content(p.rest()) + end;
}
auto tool_calls = p.rule("tool-calls",
p.trigger_rule("tool-call", p.literal(TOOL_CALL_START) +
p.trigger_rule("tool-call",
p.literal(TOOL_CALL_START) +
p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls) +
p.literal(TOOL_CALL_END)
)
@ -1349,6 +1350,80 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, TOOL_CALL_START }
};
}
return data;
}
// LFM2.5 format: uses plain "List of tools: [...]" in system prompt, no wrapper tokens.
// Tool calls are bare [name(arg="val")], though model may optionally emit <|tool_call_start|>.
// - Reasoning: <think>{reasoning}</think> (optional)
// - Content: text before a tool call (optional)
// - Tool calls: Python-style, e.g. [function_name(arg1="value1", arg2="value2")]
// Tool calls can appear multiple times (parallel tool calls supported)
static common_chat_params common_chat_params_init_lfm2_5(const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
"<|tool_call_start|>",
"<|tool_call_end|>",
"<think>",
"</think>",
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
const std::string THINK_START = "<think>";
const std::string THINK_END = "</think>";
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto generation_prompt = p.prefix(inputs.generation_prompt, THINK_START);
auto end = p.end();
auto reasoning = p.eps();
if (extract_reasoning && inputs.enable_thinking) {
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
}
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return generation_prompt + reasoning + p.content(p.rest()) + end;
}
auto tool_calls = p.rule("tool-calls",
p.trigger_rule("tool-call",
p.python_style_tool_calls(inputs.tools, inputs.parallel_tool_calls)
)
);
auto content = p.content(p.until_one_of({"<|tool_call_start|>", "["}));
auto maybe_start = p.optional(p.literal("<|tool_call_start|>"));
return generation_prompt + reasoning + content + maybe_start + tool_calls + end;
});
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
parser.build_grammar(builder, data.grammar_lazy);
});
foreach_function(inputs.tools, [&](const json & tool) {
const std::string name = tool.at("function").at("name");
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[" + name + "(" });
});
}
return data;
}
@ -1530,14 +1605,21 @@ static std::optional<common_chat_params> try_specialized_template(
return common_chat_params_init_kimi_k2(tmpl, params);
}
// LFM2 - uses <|tool_list_start|>/<|tool_list_end|> markers and <|tool_call_start|>[name(args)]<|tool_call_end|> format
// Detection: template has "<|tool_list_start|>" and "<|tool_list_end|>" markers
// LFM2 format detection: template uses <|tool_list_start|>[...]<|tool_list_end|> around the tool list
// and <|tool_call_start|>[...]<|tool_call_end|> around each tool call
if (src.find("<|tool_list_start|>") != std::string::npos &&
src.find("<|tool_list_end|>") != std::string::npos) {
LOG_DBG("Using specialized template: LFM2\n");
return common_chat_params_init_lfm2(tmpl, params);
}
// LFM2.5 format detection: template uses plain "List of tools: [...]" with no special tokens
if (src.find("List of tools: [") != std::string::npos &&
src.find("<|tool_list_start|>") == std::string::npos) {
LOG_DBG("Using specialized template: LFM2.5\n");
return common_chat_params_init_lfm2_5(tmpl, params);
}
// GigaChatV3 format detection
if (src.find("<|role_sep|>") != std::string::npos &&
src.find("<|message_sep|>") != std::string::npos &&

View File

@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 9)
set(GGML_VERSION_PATCH 10)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@ -166,15 +166,16 @@ if (NOT MSVC)
option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON)
option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")

View File

@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
{
n_tasks = n_threads;
} break;
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
{
n_tasks = n_threads;
const int64_t n_heads = node->src[1]->ne[1];
n_tasks = MIN(n_threads, n_heads);
} break;
case GGML_OP_WIN_PART:
case GGML_OP_WIN_UNPART:

View File

@ -180,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
}
#endif
#if defined(__riscv_zvfh)
template <>
inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
#if defined(__riscv_v_intrinsic)
template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) {
return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) {
return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) {
return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) {
return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif
#if defined(__riscv_zvfh)
template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) {
return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) {
return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) {
return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) {
return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif
#if defined(__riscv_zvfbfwma)
inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) {
return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1());
}
inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) {
return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2());
}
inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) {
return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4());
}
template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) {
return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8());
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -272,7 +277,7 @@ inline float hsum(__m512 x) {
}
#endif // __AVX512F__
#if defined(__riscv_zvfh)
#if defined(__riscv_v_intrinsic)
inline float hsum(vfloat32m1_t x) {
return __riscv_vfmv_f_s_f32m1_f32(
__riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1()));
@ -379,19 +384,7 @@ template <> inline __m256bh load(const float *p) {
}
#endif
#if defined(__riscv_zvfh)
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
#if defined(__riscv_v_intrinsic)
template <> inline vfloat32m1_t load(const float *p) {
return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
}
@ -406,6 +399,21 @@ template <> inline vfloat32m8_t load(const float *p) {
}
#endif
#if defined(__riscv_zvfh)
template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m1(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
}
#endif
#if defined(__riscv_zvfbfwma)
template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) {
return __riscv_vle16_v_bf16mf2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16mf2());
@ -416,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) {
template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) {
return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2());
}
template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) {
return __riscv_vle16_v_bf16m4(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m4());
}
#endif
#if defined(__riscv_zvfh)
#if defined(__riscv_v_intrinsic)
template <typename T> T set_zero();
template <> inline vfloat16mf2_t set_zero() {
return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2());
}
template <> inline vfloat16m1_t set_zero() {
return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1());
}
template <> inline vfloat16m2_t set_zero() {
return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2());
}
template <> inline vfloat16m4_t set_zero() {
return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
}
template <> inline vfloat32m1_t set_zero() {
return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
}
@ -449,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() {
#if defined(__riscv_v_intrinsic)
template <typename T> size_t vlmax() {
if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); }
else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); }
else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); }
else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); }
#if defined (__riscv_zvfh)
else if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
#endif
#if defined (__riscv_zvfbfwma)
else if constexpr (std::is_same_v<T, vbfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
else if constexpr (std::is_same_v<T, vbfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
else if constexpr (std::is_same_v<T, vbfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); }
else if constexpr (std::is_same_v<T, vbfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); }
#endif
return 0;
}
#endif
@ -3740,7 +3747,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
params->ith, params->nth};
tb.matmul(m, n);
return true;
#elif defined(__riscv_zvfh)
#elif defined(__riscv_v_intrinsic)
#if LMUL == 1
tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params,
k, (const float *)A, lda,
@ -3804,23 +3811,25 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
return true;
}
#elif defined(__riscv_zvfbfwma)
#if LMUL == 1
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#elif LMUL == 2
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#else // LMUL = 4
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#endif
return tb.matmul(m, n);
if (Btype == GGML_TYPE_BF16) {
#if LMUL == 1
tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#elif LMUL == 2
tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#else // LMUL = 4
tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params,
k, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc};
#endif
return tb.matmul(m, n);
}
#endif
return false;
}

View File

@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
const int h_start = (HEADS * (ith )) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32(
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
const int h_start = (HEADS * (ith )) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
const int ith = params->ith;
const int nth = params->nth;
if (ith >= HEADS) {
return;
}
const int h_start = (HEADS * ith) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
const int h_start = (HEADS * (ith )) / nth;
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
(HEADS * (ith + 1)) / nth : HEADS;
float * r = (float *) dst->src[0]->data;
float * w = (float *) dst->src[1]->data;

View File

@ -126,7 +126,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
const int ggml_f16_epr = sve_register_length / 16; // running when 16
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
const int np = (n & ~(ggml_f16_step - 1));
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t sum_00 = svdup_n_f16(0.0f);
svfloat16_t sum_01 = svdup_n_f16(0.0f);
@ -224,71 +224,75 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
np = n;
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
size_t vl = __riscv_vsetvlmax_e32m4();
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
size_t vl = __riscv_vsetvlmax_e32m4();
// initialize accumulators to all zeroes
vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
// initialize accumulators to all zeroes
vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
// calculate step size
const size_t epr = __riscv_vsetvlmax_e16m2();
const size_t step = epr * 2;
int np = (n & ~(step - 1));
// calculate step size
const size_t epr = __riscv_vsetvlmax_e16m2();
const size_t step = epr * 2;
const int np = (n & ~(step - 1));
// unroll by 2 along the row dimension
for (int i = 0; i < np; i += step) {
vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
// unroll by 2 along the row dimension
for (int i = 0; i < np; i += step) {
vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
}
vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
}
vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
// leftovers
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m2(n - i);
vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
// leftovers
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m2(n - i);
vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
}
vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
}
// reduce
vl = __riscv_vsetvlmax_e32m2();
vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
__riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
__riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
vl = __riscv_vsetvlmax_e32m2();
vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
__riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
__riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
// reduce
vl = __riscv_vsetvlmax_e32m2();
vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
__riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
__riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
vl = __riscv_vsetvlmax_e32m2();
vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
__riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
__riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
np = n;
#else
const int np = 0;
#endif
#else
const int np = (n & ~(GGML_F16_STEP - 1));
@ -313,21 +317,17 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
}
// leftovers
for (int i = np; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
#else
for (int i = 0; i < n; ++i) {
// scalar path
const int np = 0;
#endif
// scalar and leftovers
for (int i = np; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
s[i] = (float)sumf[i];
@ -532,40 +532,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
#elif defined(__riscv_v_intrinsic) // implies __riscv_v_intrinsic
#if defined (__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
#else
// fall to scalar path
const int np = 0;
#endif
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
@ -584,10 +589,11 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
}
}
#else
// scalar path
const int np = 0;
#endif
// leftovers
// scalar and leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
@ -785,7 +791,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
const int ggml_f16_step = 2 * ggml_f16_epr;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np = (n & ~(ggml_f16_step - 1));
int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ay1, ay2;
for (int i = 0; i < np; i += ggml_f16_step) {
@ -805,36 +811,43 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
svfloat16_t out = svmul_f16_m(pg, hy, vx);
svst1_f16(pg, (__fp16 *)(y + np), out);
}
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
np = n;
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
const _Float16 scale = *(const _Float16*)(&s);
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
const int np = (n & ~(step - 1));
// calculate step size
const int epr = __riscv_vsetvlmax_e16m4();
const int step = epr * 2;
int np = (n & ~(step - 1));
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
// unroll by 2
for (int i = 0; i < np; i += step) {
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
__asm__ __volatile__ ("" ::: "memory");
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
__riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
__asm__ __volatile__ ("" ::: "memory");
}
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
// leftovers
int vl;
for (int i = np; i < n; i += vl) {
vl = __riscv_vsetvl_e16m4(n - i);
vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
__riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
np = n;
#else
// fall to scalar path
const int np = 0;
#endif
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
@ -850,17 +863,14 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
#else
// scalar path
const int np = 0;
#endif
// scalar and leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
}
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }

View File

@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
}
static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
#ifdef FP8_AVAILABLE
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
#if defined(GGML_USE_HIP) && defined(CDNA3)
// ROCm dose not support fp8 in software on devices with fp8 hardware,
#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
// ROCm does not support fp8 in software on devices with fp8 hardware,
// but CDNA3 supports only e4m3_fnuz (no inf).
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
#else
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP)
return static_cast<float>(xf) / 2;
#else
NO_DEVICE_CODE;
#endif // FP8_AVAILABLE
#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
return static_cast<float>(xf) / 2;
#else
if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f
return 0.0f;
}
const int exp = (x >> 3) & 0xF;
const int man = x & 0x7;
float raw;
if (exp == 0) {
raw = ldexpf((float) man, -9);
} else {
raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
}
return static_cast<float>(raw / 2);
#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
}
__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {

View File

@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@ -80,6 +85,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@ -89,6 +99,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
@ -103,6 +118,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@ -1552,7 +1571,7 @@ static __global__ void flash_attn_ext_f16(
#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
NO_DEVICE_CODE;
return;
}
@ -1815,6 +1834,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
// The number of viable configurations for Deepseek is very limited:
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);

View File

@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 512: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
} break;
case 576: {
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);

View File

@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
@ -767,7 +785,7 @@ static __global__ void flash_attn_tile(
#ifdef GGML_USE_WMMA_FATTN
(ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
#endif // GGML_USE_WMMA_FATTN
(use_logit_softcap && !(DV == 128 || DV == 256))
(use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512))
) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
@ -1192,7 +1210,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
if constexpr (DV == 512) {
if constexpr (DKQ == 576) {
if (use_gqa_opt && gqa_ratio % 16 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
@ -1203,7 +1221,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
}
}
if constexpr (DV <= 256) {
if constexpr (DKQ <= 512) {
if (use_gqa_opt && gqa_ratio % 8 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
return;
@ -1214,13 +1232,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
return;
}
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
if constexpr (DV <= 256) {
if (use_gqa_opt && gqa_ratio % 2 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
return;
}
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
return;
}
GGML_ABORT("fatal error");
}
@ -1255,4 +1275,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96);
extern DECL_FATTN_TILE_CASE(112, 112);
extern DECL_FATTN_TILE_CASE(128, 128);
extern DECL_FATTN_TILE_CASE(256, 256);
extern DECL_FATTN_TILE_CASE(512, 512);
extern DECL_FATTN_TILE_CASE(576, 512);

View File

@ -135,6 +135,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(V->ne[0] == 256);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 512:
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
break;
case 576: {
// For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
@ -336,7 +340,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
case 128:
case 112:
case 256:
if (V->ne[0] != K->ne[0]) {
case 512:
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE;
}
break;
@ -424,7 +429,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
@ -457,7 +462,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// Use MFMA flash attention for CDNA (MI100+):
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
// MMA vs tile crossover benchmarked on MI300X @ d32768:
// hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%)

View File

@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
#ifdef FP8_AVAILABLE
case GGML_TYPE_NVFP4:
#endif // FP8_AVAILABLE
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:

View File

@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
case GGML_TYPE_MXFP4:
mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
break;
case GGML_TYPE_NVFP4:
mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
}
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

View File

@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_MXFP4:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_NVFP4:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q2_K:
return MMQ_Q8_1_DS_LAYOUT_D2S6;
case GGML_TYPE_Q3_K:
@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16;
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
}
}
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4
#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4
#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
// tile sizes are the same for Q8_1 and FP4 for blackwell
case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
}
}
template <int mmq_y, bool need_check>
static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
int * __restrict__ x_tile,
const int kb0,
const int i_max,
const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
constexpr int rows_per_warp = warp_size / threads_per_row;
const int kbx = threadIdx.x % threads_per_row;
const int row_in_warp = threadIdx.x / threads_per_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
if constexpr (need_check) {
i = min(i, i_max);
}
const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
const int kqs = 16 * kbx;
const int ksc = 4 * kbx;
#pragma unroll
for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
#else
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
}
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
// Used for Q3_K, IQ2_S, and IQ2_XS
// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@ -3261,6 +3327,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);

View File

@ -235,30 +235,33 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
// Host function: returns the max batch size for the current arch+type at runtime.
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
// NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return MMVQ_MAX_BATCH_SIZE;
}
if (cc >= GGML_CUDA_CC_TURING) {
return get_mmvq_mmid_max_batch_turing_plus(type);
}
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return MMVQ_MAX_BATCH_SIZE;
}
if (cc >= GGML_CUDA_CC_TURING) {
return get_mmvq_mmid_max_batch_turing_plus(type);
}
return get_mmvq_mmid_max_batch_pascal_older(type);
}
// AMD
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return get_mmvq_mmid_max_batch_rdna4(type);
}
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return get_mmvq_mmid_max_batch_rdna3(type);
}
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
}
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return get_mmvq_mmid_max_batch_cdna(type);
}
if (GGML_CUDA_CC_IS_GCN(cc)) {
return get_mmvq_mmid_max_batch_gcn(type);
if (GGML_CUDA_CC_IS_AMD(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return get_mmvq_mmid_max_batch_rdna4(type);
}
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return get_mmvq_mmid_max_batch_rdna3(type);
}
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
}
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return get_mmvq_mmid_max_batch_cdna(type);
}
if (GGML_CUDA_CC_IS_GCN(cc)) {
return get_mmvq_mmid_max_batch_gcn(type);
}
}
return MMVQ_MAX_BATCH_SIZE;
}

View File

@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);

View File

@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);

View File

@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);

View File

@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);

View File

@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);

View File

@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);

View File

@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);

View File

@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(512, 512);

View File

@ -3,7 +3,7 @@
from glob import glob
import os
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
@ -35,7 +35,7 @@ TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4"
]
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
@ -83,6 +83,8 @@ for ncols in [8, 16, 32, 64]:
continue
if head_size_kq == 72:
continue
if head_size_kq == 512 and ncols2 not in (4, 8):
continue
if head_size_kq != 576 and ncols2 in (16, 32):
continue
if head_size_kq == 576 and ncols2 not in (4, 16, 32):

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_NVFP4);

View File

@ -23,6 +23,7 @@
#include "ggml-impl.h"
#include "ggml-sycl.h"
#include "presets.hpp"
#include "type.hpp"
#include "sycl_hw.hpp"
namespace syclexp = sycl::ext::oneapi::experimental;
@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) {
return val;
}
static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) {
const uint32_t bits = x * (x != 0x7F && x != 0xFF);
const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
return static_cast<float>(xf) / 2;
}
#endif // GGML_SYCL_COMMON_HPP

View File

@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
});
}
template <typename dst_t>
static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
GGML_ASSERT(k % QK_NVFP4 == 0);
const int nb = k / QK_NVFP4;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_nvfp4(vx, y, k);
});
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_sycl;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
return nullptr;
}
}
@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_sycl;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
#ifdef GGML_SYCL_HAS_BF16
@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
return nullptr;
}
}

View File

@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr
}
}
template <typename dst_t>
static void dequantize_block_nvfp4(
const void * __restrict__ vx,
dst_t * __restrict__ yy,
const int64_t ne) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
const int64_t i = item_ct1.get_group(2);
const int tid = item_ct1.get_local_id(2);
const int64_t base = i * QK_NVFP4;
if (base >= ne) {
return;
}
const block_nvfp4 * x = (const block_nvfp4 *) vx;
const block_nvfp4 & xb = x[i];
const int sub = tid / (QK_NVFP4_SUB / 2);
const int j = tid % (QK_NVFP4_SUB / 2);
const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]);
const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
}
#endif // GGML_SYCL_DEQUANTIZE_HPP

View File

@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
}
}
static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_NVFP4 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
case GGML_TYPE_MXFP4:
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_NVFP4:
mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
default:
GGML_ABORT("fatal error");
GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
}
}
GGML_UNUSED(src1);

112
ggml/src/ggml-sycl/type.hpp Normal file
View File

@ -0,0 +1,112 @@
#pragma once
#include <sycl/sycl.hpp>
#include <cstdint>
#include <limits>
inline uint8_t float_to_e4m3(float f)
{
if (sycl::isnan(f)) {
return 0x7F; // Canonical NaN (positive)
}
uint32_t bits = sycl::bit_cast<uint32_t>(f);
uint32_t sign = (bits >> 31) & 0x1u;
uint32_t exp = (bits >> 23) & 0xFFu;
uint32_t mant = bits & 0x7FFFFFu;
// Zero
if (exp == 0 && mant == 0) {
return static_cast<uint8_t>(sign << 7);
}
// Extract biased exponent and mantissa for FP8
int e = static_cast<int>(exp) - 127; // true exponent (IEEE bias 127)
uint32_t m = mant;
// Handle very large values → NaN (NVIDIA behavior for E4M3)
if (e > 7) { // max exponent for E4M3 is 7 (biased 14)
return static_cast<uint8_t>((sign << 7) | 0x7F);
}
// Handle subnormals and normal numbers
if (e < -6) { // smallest normal exponent is -6
// Subnormal in FP8: shift mantissa right
int shift = -6 - e;
m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position
if (shift > 23) m = 0;
} else {
// Normal number: adjust exponent bias from 127 to 7
int new_exp = e + 7;
m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1)
m |= (static_cast<uint32_t>(new_exp) << 3);
}
// Round-to-nearest-even (simple guard + round bit)
// For better accuracy you can add sticky bit, but this is sufficient for most use cases
uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits
if (round_bit) {
m += 1;
// Carry into exponent if mantissa overflows
if ((m & 0x8u) != 0) {
m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling
// If exponent overflows after carry → NaN
if ((m >> 3) > 14) {
return static_cast<uint8_t>((sign << 7) | 0x7F);
}
}
}
uint8_t result = static_cast<uint8_t>((sign << 7) | (m & 0x7F));
return result;
}
inline float e4m3_to_float(uint8_t x)
{
if (x == 0) return 0.0f;
uint8_t sign = (x >> 7) & 0x1u;
uint8_t exp = (x >> 3) & 0xFu;
uint8_t mant = x & 0x7u;
// NaN (NVIDIA uses 0x7F / 0xFF as NaN)
if (exp == 0xF && mant != 0) {
return std::numeric_limits<float>::quiet_NaN();
}
if (exp == 0xF) { // 0x7F or 0xFF treated as NaN
return std::numeric_limits<float>::quiet_NaN();
}
float val;
if (exp == 0) {
// Subnormal
val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f);
} else {
// Normal: implicit leading 1 + bias 7
val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast<float>(exp) - 7.0f);
}
return sign ? -val : val;
}
// The actual type definition
struct __nv_fp8_e4m3 {
uint8_t raw;
__nv_fp8_e4m3() = default;
explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {}
explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {}
operator float() const { return e4m3_to_float(raw); }
operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); }
// Allow direct access for vector loads/stores
operator uint8_t&() { return raw; }
operator uint8_t() const { return raw; }
};
using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>;
using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>;

View File

@ -15,6 +15,7 @@
#include "dpct/helper.hpp"
#include "ggml.h"
#include "type.hpp"
#include "quants.hpp"
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
return x32;
}
static __dpct_inline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
int x32 = x16[2*i32 + 0] << 0;
x32 |= x16[2*i32 + 1] << 16;
return x32;
}
static __dpct_inline__ int get_int_b4(const void * x, const int & i32) {
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
const uint16_t* x16 =
@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
return d * sumi;
}
#define VDR_NVFP4_Q8_1_MMVQ 4
#define VDR_NVFP4_Q8_1_MMQ 8
static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq,
const block_q8_1 * __restrict__ bq8_1,
const int32_t & iqs) {
const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq;
float sum = 0.0f;
#pragma unroll
for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
const int32_t iqs0 = iqs + 2*i;
const int32_t iqs1 = iqs0 + 1;
const int32_t is = iqs0 >> 1;
const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
const block_q8_1 * bq8 = bq8_1 + (is >> 1);
const int32_t i8 = ((is & 1) << 2);
int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0);
sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi);
sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi);
sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi);
const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0];
sum += d * float(sumi);
}
return sum;
}
static __dpct_inline__ float
vec_dot_q5_0_q8_1(const void *__restrict__ vbq,

View File

@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions {
uint32_t mul_mat_wg_size;
};
/** Cpy **/
struct ggml_webgpu_cpy_pipeline_key {
ggml_type src_type;
ggml_type dst_type;
bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const {
return src_type == other.src_type && dst_type == other.dst_type;
}
};
struct ggml_webgpu_cpy_pipeline_key_hash {
size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
return seed;
}
};
/** Glu **/
struct ggml_webgpu_glu_pipeline_key {
ggml_glu_op glu_op;
ggml_type type;
bool split;
bool operator==(const ggml_webgpu_glu_pipeline_key & other) const {
return glu_op == other.glu_op && type == other.type && split == other.split;
}
};
struct ggml_webgpu_glu_pipeline_key_hash {
size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.glu_op);
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.split);
return seed;
}
};
/** Rope **/
struct ggml_webgpu_rope_pipeline_key {
ggml_type type;
bool inplace;
bool has_ff;
bool operator==(const ggml_webgpu_rope_pipeline_key & other) const {
return type == other.type && inplace == other.inplace && has_ff == other.has_ff;
}
};
struct ggml_webgpu_rope_pipeline_key_hash {
size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.has_ff);
return seed;
}
};
/** SoftMax **/
struct ggml_webgpu_soft_max_pipeline_key {
ggml_type mask_type;
bool has_mask;
bool has_sink;
bool inplace;
bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const {
return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink &&
inplace == other.inplace;
}
};
struct ggml_webgpu_soft_max_pipeline_key_hash {
size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.mask_type);
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sink);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
rope_pipelines;
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
soft_max_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@ -1124,9 +1219,8 @@ class ggml_webgpu_shader_lib {
defines.push_back("BYTE_HELPERS");
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
break;
}
}
@ -1239,9 +1333,8 @@ class ggml_webgpu_shader_lib {
defines.push_back("MUL_ACC_" + type_upper);
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
// Use f16 inside the shader for quantized types
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("U32_DEQUANT_HELPERS");
defines.push_back("SRC0_INNER_TYPE=u32");
variant += std::string("_") + src0_name;
break;
@ -1679,6 +1772,236 @@ class ggml_webgpu_shader_lib {
return flash_attn_pipelines[key];
}
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_cpy_pipeline_key key = {
.src_type = context.src0->type,
.dst_type = context.dst->type,
};
auto it = cpy_pipelines.find(key);
if (it != cpy_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "cpy";
switch (key.src_type) {
case GGML_TYPE_F32:
defines.push_back("SRC_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("SRC_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported src type for cpy shader");
}
switch (key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
variant += "_f16";
break;
case GGML_TYPE_I32:
defines.push_back("DST_I32");
variant += "_i32";
break;
default:
GGML_ABORT("Unsupported dst type for cpy shader");
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_cpy, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
cpy_pipelines[key] = pipeline;
return cpy_pipelines[key];
}
webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_glu_pipeline_key key = {
.glu_op = ggml_get_glu_op(context.dst),
.type = context.dst->type,
.split = (context.src1 != nullptr),
};
auto it = glu_pipelines.find(key);
if (it != glu_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "glu";
switch (key.glu_op) {
case GGML_GLU_OP_REGLU:
defines.push_back("OP_REGLU");
variant += "_reglu";
break;
case GGML_GLU_OP_GEGLU:
defines.push_back("OP_GEGLU");
variant += "_geglu";
break;
case GGML_GLU_OP_SWIGLU:
defines.push_back("OP_SWIGLU");
variant += "_swiglu";
break;
case GGML_GLU_OP_SWIGLU_OAI:
defines.push_back("OP_SWIGLU_OAI");
variant += "_swiglu_oai";
break;
case GGML_GLU_OP_GEGLU_ERF:
defines.push_back("OP_GEGLU_ERF");
variant += "_geglu_erf";
break;
case GGML_GLU_OP_GEGLU_QUICK:
defines.push_back("OP_GEGLU_QUICK");
variant += "_geglu_quick";
break;
default:
GGML_ABORT("Unsupported GLU op");
}
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for GLU shader");
}
if (key.split) {
variant += "_split";
} else {
defines.push_back("NO_SPLIT");
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_glu, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
glu_pipelines[key] = pipeline;
return glu_pipelines[key];
}
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_rope_pipeline_key key = {
.type = context.dst->type,
.inplace = context.inplace,
.has_ff = (context.src2 != nullptr),
};
auto it = rope_pipelines.find(key);
if (it != rope_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "rope";
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for ROPE shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
if (key.has_ff) {
defines.push_back("FF_FUNC");
variant += "_ff";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_rope, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
rope_pipelines[key] = pipeline;
return rope_pipelines[key];
}
webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_soft_max_pipeline_key key = {
.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
.has_mask = (context.src1 != nullptr),
.has_sink = (context.src2 != nullptr),
.inplace = context.inplace,
};
auto it = soft_max_pipelines.find(key);
if (it != soft_max_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "soft_max";
if (key.has_mask) {
defines.push_back("HAS_MASK");
switch (key.mask_type) {
case GGML_TYPE_F32:
defines.push_back("MASK_F32");
variant += "_mask_f32";
break;
case GGML_TYPE_F16:
defines.push_back("MASK_F16");
variant += "_mask_f16";
break;
default:
GGML_ABORT("Unsupported type for SOFT_MAX shader");
}
}
if (key.has_sink) {
defines.push_back("HAS_SINK");
variant += "_sink";
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
soft_max_pipelines[key] = pipeline;
return soft_max_pipelines[key];
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,

View File

@ -83,7 +83,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
#define WEBGPU_NUM_PARAM_BUFS 96u
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
// parameter buffer pool
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
@ -171,6 +171,7 @@ struct webgpu_buf_pool {
// Try growing the pool if no free buffers
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
cur_pool_size++;
lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks
wgpu::Buffer dev_buf;
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
@ -364,13 +365,6 @@ struct webgpu_context_struct {
wgpu::Buffer set_rows_dev_error_buf;
wgpu::Buffer set_rows_host_error_buf;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
size_t memset_bytes_per_thread;
};
@ -514,7 +508,7 @@ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD;
while (blocking_wait) {
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0);
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6);
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
#ifdef GGML_WEBGPU_GPU_PROFILE
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
@ -735,7 +729,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
std::vector<webgpu_command> commands = { command };
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
ggml_backend_webgpu_wait(ctx, sub);
}
/** End WebGPU Actions */
@ -849,6 +842,16 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
}
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = {
@ -875,9 +878,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
params, entries, wg_x);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@ -1914,6 +1916,19 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = ggml_webgpu_tensor_equal(src0, dst),
};
webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_freq_factor = (src2 != nullptr);
@ -1996,12 +2011,22 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const int split = (src1 != nullptr);
std::vector<uint32_t> params = {
@ -2048,8 +2073,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@ -2109,9 +2133,20 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
const int has_sink = (src2 != nullptr);
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = ggml_webgpu_tensor_equal(src0, dst),
};
webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_mask = (src1 != nullptr);
const int has_sink = (src2 != nullptr);
float max_bias;
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
@ -2120,15 +2155,15 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
@ -2136,8 +2171,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
has_mask ? (uint32_t) src1->ne[2] : 0,
has_mask ? (uint32_t) src1->ne[3] : 0,
*(uint32_t *) dst->op_params, // scale
*(uint32_t *) &max_bias,
*(uint32_t *) &n_head_log2,
@ -2152,7 +2187,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, src0) }
};
uint32_t binding_num = 1;
if (mask_type < 2) {
if (has_mask) {
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
@ -2173,9 +2208,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
ggml_nrows(dst));
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst));
}
static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@ -2661,17 +2694,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
// memset the remaining bytes
ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
total_offset + (size - remaining_size), remaining_size);
} else {
// wait for WriteBuffer to complete
buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
wgpu::CallbackMode::AllowSpontaneous,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
if (status != wgpu::QueueWorkDoneStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
std::string(message).c_str());
}
}),
UINT64_MAX);
}
WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
}
@ -2885,139 +2907,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
}
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
// REGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
// GEGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
// SWIGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
// SWIGLU_OAI
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
// GEGLU_ERF
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
// GEGLU_QUICK
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
}
static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
// f32 (no mask)
webgpu_ctx->soft_max_pipelines[2][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
// f32 mask (mask_type = 0)
webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
webgpu_ctx->soft_max_pipelines[0][1][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
"soft_max_f32_mask_f32_sink_inplace", constants);
// f16 mask (mask_type = 1)
webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
webgpu_ctx->soft_max_pipelines[1][1][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
"soft_max_f32_mask_f16_sink_inplace", constants);
}
static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
wgpu::RequestAdapterOptions options = {};
@ -3183,10 +3072,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,

View File

@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
}
#endif
#ifdef U32_DEQUANT_HELPERS
fn load_src0_u16_at(byte_offset: u32) -> u32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_src0_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = src0[word_idx];
if (shift == 0u) {
return lo;
}
let hi = src0[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn load_src0_f16_at(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_src0_u16_at(byte_offset));
return f16(packed[0]);
}
#endif
#ifdef Q4_0_T
struct q4_0 {
d: f16,

View File

@ -1,66 +1,41 @@
#define(VARIANTS)
[
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "i32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f32"
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#ifdef SRC_F32
#define SRC_TYPE f32
#elif defined(SRC_F16)
#define SRC_TYPE f16
#endif
#ifdef DST_F32
#define DST_TYPE f32
#elif defined(DST_F16)
#define DST_TYPE f16
#elif defined(DST_I32)
#define DST_TYPE i32
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<{{SRC_TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
struct Params{
ne: u32,
offset_src: u32,
offset_dst: u32,
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
@ -73,8 +48,7 @@ struct Params {
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
@ -102,6 +76,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;
dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx]));
}
#end(SHADER)

View File

@ -1,41 +1,8 @@
import os
import re
import ast
import argparse
def extract_block(text, name):
pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
match = re.search(pattern, text, re.DOTALL)
if not match:
raise ValueError(f"Missing block: {name}")
return match.group(1).strip()
def parse_decls(decls_text):
decls = {}
for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
decls[name.strip()] = code.strip()
return decls
def replace_repl_placeholders(variant, template_map):
for repl, code in variant["REPLS"].items():
for key, val in template_map.items():
# Match "key" and avoid matching subsequences using by using \b
code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
variant["REPLS"][repl] = code
return variant
def replace_placeholders(shader_text, replacements):
for key, val in replacements.items():
# Match {{KEY}} literally, where KEY is escaped
pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
shader_text = re.sub(pattern, str(val), shader_text)
return shader_text
def expand_includes(shader, input_dir):
"""
Replace #include "file" lines in the text with the contents of that file.
@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
def generate_variants(fname, input_dir, output_dir, outfile):
shader_path = os.path.join(input_dir, fname)
shader_base_name = fname.split(".")[0]
with open(shader_path, "r", encoding="utf-8") as f:
text = f.read()
try:
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
except ValueError:
write_shader(shader_base_name, text, output_dir, outfile, input_dir)
else:
try:
decls_map = parse_decls(extract_block(text, "DECLS"))
except ValueError:
decls_map = {}
try:
templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
except ValueError:
templates_map = {}
for fname in sorted(os.listdir(input_dir)):
if fname.endswith(".tmpl"):
tmpl_path = os.path.join(input_dir, fname)
with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
decls = f_tmpl.read()
decls_map.update(parse_decls(decls))
shader_template = extract_block(text, "SHADER")
for variant in variants:
if "DECLS" in variant:
decls = variant["DECLS"]
else:
decls = []
decls_code = ""
for key in decls:
if key not in decls_map:
raise ValueError(f"DECLS key '{key}' not found.")
decls_code += decls_map[key] + "\n\n"
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
if "REPLS" in variant:
variant = replace_repl_placeholders(variant, templates_map)
final_shader = replace_placeholders(final_shader, variant["REPLS"])
# second run to expand placeholders in repl_template
final_shader = replace_placeholders(final_shader, variant["REPLS"])
final_shader = expand_includes(final_shader, input_dir)
if "SHADER_NAME" in variant:
output_name = variant["SHADER_NAME"]
elif "SHADER_SUFFIX" in variant:
output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
else:
output_name = shader_base_name
write_shader(output_name, final_shader, output_dir, outfile, input_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", required=True)
parser.add_argument("--output_file", required=True)
parser.add_argument("--output_dir")
args = parser.parse_args()
if args.output_dir:
os.makedirs(args.output_dir, exist_ok=True)
with open(args.output_file, "w", encoding="utf-8") as out:
out.write("// Auto-generated shader embedding\n")
out.write("#include <string>\n\n")
for fname in sorted(os.listdir(args.input_dir)):
if fname.endswith(".wgsl"):
generate_variants(fname, args.input_dir, args.output_dir, out)
shader_path = os.path.join(args.input_dir, fname)
shader_name = fname.replace(".wgsl", "")
with open(shader_path, "r", encoding="utf-8") as f:
shader_code = f.read()
write_shader(shader_name, shader_code, None, out, args.input_dir)
if __name__ == "__main__":

View File

@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix;
#ifdef KV_F32
#define KV_TYPE f32
#elif defined(KV_Q4_0) || defined(KV_Q8_0)
#define KV_TYPE u32
#else
#define KV_TYPE f16
#endif
@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix;
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
#define BLOCK_SIZE_BYTES 18u
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
#define BLOCK_SIZE_BYTES 34u
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}
#if defined(KV_Q4_0) || defined(KV_Q8_0)
fn load_k_u16_at(byte_offset: u32) -> u32 {
let word = K[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_k_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = K[word_idx];
if (shift == 0u) {
return lo;
}
let hi = K[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn load_v_u16_at(byte_offset: u32) -> u32 {
let word = V[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_v_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = V[word_idx];
if (shift == 0u) {
return lo;
}
let hi = V[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}
fn f16_from_u16(bits: u32) -> f16 {
let packed = unpack2x16float(bits);
return f16(packed[0]);
}
#endif
struct Params {
offset_q: u32,
offset_k: u32,
@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx]; // scale
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;

View File

@ -1,323 +0,0 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "reglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "REGLU"]
},
{
"SHADER_NAME": "geglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "swiglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_oai_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
},
{
"SHADER_NAME": "swiglu_oai_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "SWIGLU_OAI"]
},
{
"SHADER_NAME": "geglu_erf_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_quick_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU_QUICK"]
},
]
#end(VARIANTS)
#define(DECLS)
#decl(REGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return max(a, 0) * b;
}
#enddecl(REGLU)
#decl(GEGLU)
const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
const GELU_COEF_A: {{TYPE}} = 0.044715;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
}
#enddecl(GEGLU)
#decl(SWIGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return a / (1.0 + exp(-a)) * b;
}
#enddecl(SWIGLU)
#decl(SWIGLU_OAI)
fn op(a: f32, b: f32) -> f32 {
let xi = min(a, params.limit);
let gi = max(min(b, params.limit), -params.limit);
var out_glu = xi / (1.0 + exp(-xi * params.alpha));
out_glu = out_glu * (1.0 + gi);
return out_glu;
}
#enddecl(SWIGLU_OAI)
#decl(GEGLU_ERF)
const p_erf: {{TYPE}} = 0.3275911;
const a1_erf: {{TYPE}} = 0.254829592;
const a2_erf: {{TYPE}} = -0.284496736;
const a3_erf: {{TYPE}} = 1.421413741;
const a4_erf: {{TYPE}} = -1.453152027;
const a5_erf: {{TYPE}} = 1.061405429;
const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
let a_div_sqr2 = a * SQRT_2_INV;
let sign_x = sign(a_div_sqr2);
let x = abs(a_div_sqr2);
let t = 1.0 / (1.0 + p_erf * x);
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
let erf_approx = sign_x * y;
return 0.5 * a * (1.0 + erf_approx) * b;
}
#enddecl(GEGLU_ERF)
#decl(GEGLU_QUICK)
const GELU_QUICK_COEF: {{TYPE}} = -1.702;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#enddecl(GEGLU_QUICK)
#decl(NO_SPLIT)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
fn a_value(base: u32) -> {{TYPE}} {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
return src0[base + offset];
}
fn b_value(base: u32) -> {{TYPE}} {
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
return src0[base + offset];
}
#enddecl(NO_SPLIT)
#decl(SPLIT)
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
fn a_value(base: u32) -> {{TYPE}} {
return src0[base];
}
fn b_value(base: u32) -> {{TYPE}} {
return src1[base];
}
#enddecl(SPLIT)
#end(DECLS)
#define(SHADER)
enable f16;
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of dst
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
swapped: u32,
alpha: f32,
limit: f32,
}
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
dst[i_dst] = op(a_value(i_a), b_value(i_b));
}
#end(SHADER)

View File

@ -0,0 +1,155 @@
enable f16;
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
#ifdef OP_REGLU
fn op(a: DataType, b: DataType) -> DataType {
return max(a, 0) * b;
}
#endif
#ifdef OP_GEGLU
const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876;
const GELU_COEF_A: DataType = 0.044715;
fn op(a: DataType, b: DataType) -> DataType {
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b;
}
#endif
#ifdef OP_SWIGLU
fn op(a: DataType, b: DataType) -> DataType {
return a / (1.0 + exp(-a)) * b;
}
#endif
#ifdef OP_SWIGLU_OAI
fn op(a: f32, b: f32) -> f32 {
let xi = min(a, params.limit);
let gi = max(min(b, params.limit), -params.limit);
var out_glu = xi / (1.0 + exp(-xi * params.alpha));
out_glu = out_glu * (1.0 + gi);
return out_glu;
}
#endif
#ifdef OP_GEGLU_ERF
const p_erf: DataType = 0.3275911;
const a1_erf: DataType = 0.254829592;
const a2_erf: DataType = -0.284496736;
const a3_erf: DataType = 1.421413741;
const a4_erf: DataType = -1.453152027;
const a5_erf: DataType = 1.061405429;
const SQRT_2_INV: DataType = 0.7071067811865476;
fn op(a: DataType, b: DataType) -> DataType {
let a_div_sqr2 = a * SQRT_2_INV;
let sign_x = sign(a_div_sqr2);
let x = abs(a_div_sqr2);
let t = 1.0 / (1.0 + p_erf * x);
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
let erf_approx = sign_x * y;
return 0.5 * a * (1.0 + erf_approx) * b;
}
#endif
#ifdef OP_GEGLU_QUICK
const GELU_QUICK_COEF: DataType = -1.702;
fn op(a: DataType, b: DataType) -> DataType {
return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#endif
struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of dst
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
swapped: u32,
alpha: f32,
limit: f32,
}
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
#ifdef NO_SPLIT
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
fn a_value(base: u32) -> DataType {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
return src0[base + offset];
}
fn b_value(base: u32) -> DataType {
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
return src0[base + offset];
}
#else
@group(0) @binding(1)
var<storage, read_write> src1: array<DataType>;
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
fn a_value(base: u32) -> DataType {
return src0[base];
}
fn b_value(base: u32) -> DataType {
return src1[base];
}
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
dst[i_dst] = op(a_value(i_a), b_value(i_b));
}

View File

@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 20u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_0
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 22u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = load_src0_f16_at(block_byte_base);
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_1
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 24u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 34u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 1u + block_offset + j];
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 36u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let d = src0[scale_idx];
let m = src0[scale_idx + 1u];
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_src0_f16_at(block_byte_base);
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 42u;
const BLOCK_SIZE_BYTES = 84u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx + 40u];
let dmin = src0[scale_idx + 41u];
let d = load_src0_f16_at(block_byte_base + 80u);
let dmin = load_src0_f16_at(block_byte_base + 82u);
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let is = k_in_block / 16u;
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 55u;
const BLOCK_SIZE_BYTES = 110u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx + 54u];
let d = load_src0_f16_at(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
let scale_0 = src0[scale_idx + 48u + (2u*i)];
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
let hmask_0 = src0[scale_idx + (2u*i)];
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
let qs_0 = src0[scale_idx + 16u + (2u*i)];
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 72u;
const BLOCK_SIZE_BYTES = 144u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
let d = load_src0_f16_at(block_byte_base);
let dmin = load_src0_f16_at(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
}
// Map k_in_block to loop structure:
@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 88u;
const BLOCK_SIZE_BYTES = 176u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = src0[scale_idx];
let dmin = src0[scale_idx + 1u];
let d = load_src0_f16_at(block_byte_base);
let dmin = load_src0_f16_at(block_byte_base + 2u);
// Load packed scales
var scale_vals: array<u32, 3>;
for (var i: u32 = 0u; i < 3u; i++) {
let scale_0 = src0[scale_idx + 2u + (2u*i)];
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i);
}
// The original loop processes elements in groups of 64
@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u));
let qh_byte = get_byte(qh_packed, l % 4u);
@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
const BLOCK_SIZE_BYTES = 210u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let scale_idx = src0_idx * F16_PER_BLOCK;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13_word = ql13_flat / 4u;
let ql13 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql13_word],
src0[scale_idx + 2u * ql13_word + 1u]
));
let ql13_b = get_byte(ql13, ql13_flat % 4u);
let ql13 = load_src0_u32_at(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24_word = ql24_flat / 4u;
let ql24 = bitcast<u32>(vec2(
src0[scale_idx + 2u * ql24_word],
src0[scale_idx + 2u * ql24_word + 1u]
));
let ql24_b = get_byte(ql24, ql24_flat % 4u);
let ql24 = load_src0_u32_at(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh_word = qh_flat / 4u;
let qh = bitcast<u32>(vec2(
src0[scale_idx + 64u + 2u * qh_word],
src0[scale_idx + 64u + 2u * qh_word + 1u]
));
let qh_b = get_byte(qh, qh_flat % 4u);
let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc_word = sc_idx / 4u;
let sc = bitcast<u32>(vec2(
src0[scale_idx + 96u + 2u * sc_word],
src0[scale_idx + 96u + 2u * sc_word + 1u]
));
let sc_val = get_byte_i32(sc, sc_idx % 4u);
let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = src0[scale_idx + 104u];
let d = load_src0_f16_at(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {

View File

@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q4_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 18u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let d = f32(load_src0_f16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q4_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 20u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 10u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = f32(src0[scale_idx + 1u]);
let d = f32(load_src0_f16_at(block_byte_base));
let m = f32(load_src0_f16_at(block_byte_base + 2u));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q5_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 22u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 11u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let qh0 = src0[scale_idx + 1u];
let qh1 = src0[scale_idx + 2u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = f32(load_src0_f16_at(block_byte_base));
let qh_packed = load_src0_u32_at(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q5_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 24u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 12u;
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
let qh0 = src0[scale_idx + 2u];
let qh1 = src0[scale_idx + 3u];
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
let d = f32(load_src0_f16_at(block_byte_base));
let m = load_src0_f16_at(block_byte_base + 2u);
let qh_packed = load_src0_u32_at(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_src0_u32_at(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q8_0
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 34u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 17u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let d = f32(load_src0_f16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 1 + block_offset + j];
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q8_1
const BLOCK_SIZE = 32;
const BLOCK_SIZE_BYTES = 36u;
const NQ = 16u; // number of weights per thread
const F16_PER_BLOCK = 18u;
const WEIGHTS_PER_F16 = 2u;
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(src0[scale_idx]);
let m = src0[scale_idx + 1u];
let d = f32(load_src0_f16_at(block_byte_base));
let m = load_src0_f16_at(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = src0[scale_idx + 2u + block_offset + j];
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_src0_u32_at(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
#ifdef MUL_ACC_Q6_K
const BLOCK_SIZE = 256u;
const F16_PER_BLOCK = 105u;
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
let aligned = byte_offset & ~3u;
let idx = bbase + aligned / 2u;
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
}
const BLOCK_SIZE_BYTES = 210u;
fn byte_of(v: u32, b: u32) -> u32 {
return (v >> (b * 8u)) & 0xFFu;
@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
var local_sum = 0.0;
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
let d_raw = load_u32_at(bbase, 208u);
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
let d = f32(load_src0_f16_at(bbase + 208u));
let ql1_u32 = load_u32_at(bbase, q_offset_l);
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
let ql1_u32 = load_src0_u32_at(bbase + q_offset_l);
let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u);
let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h);
let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte);
let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);

View File

@ -1,138 +1,12 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f32_inplace",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
},
{
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f16_inplace",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
},
{
"SHADER_SUFFIX": "f32_ff",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f32_ff_inplace",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
},
{
"SHADER_SUFFIX": "f16_ff",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f16_ff_inplace",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(ROTATE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
dst[i_dst0] = {{TYPE}}(out0);
dst[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE)
#decl(ROTATE_INPLACE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
src0[i_dst0] = {{TYPE}}(out0);
src0[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE_INPLACE)
#decl(NO_FF_FUNC)
fn freq_factor(i: u32) -> f32 {
return 1.0f;
}
#enddecl(NO_FF_FUNC)
#decl(FF_FUNC)
fn freq_factor(i: u32) -> f32 {
return src2[params.offset_src2 + i/2];
}
#enddecl(FF_FUNC)
#decl(NO_FF_BINDINGS)
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(NO_FF_BINDINGS)
#decl(NO_FF_BINDINGS_INPLACE)
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NO_FF_BINDINGS_INPLACE)
#decl(FF_BINDINGS)
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(FF_BINDINGS)
#decl(FF_BINDINGS_INPLACE)
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(FF_BINDINGS_INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
struct Params {
offset_src0: u32,
offset_src1: u32,
@ -168,12 +42,69 @@ struct Params {
};
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> src1: array<i32>;
DECLS
#ifdef INPLACE
#ifdef FF_FUNC
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<uniform> params: Params;
#endif
#else
#ifdef FF_FUNC
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(4)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif
#ifdef FF_FUNC
fn freq_factor(i: u32) -> f32 {
return src2[params.offset_src2 + i/2];
}
#else
fn freq_factor(i: u32) -> f32 {
return 1.0f;
}
#endif
#ifdef INPLACE
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
src0[i_dst0] = DataType(out0);
src0[i_dst1] = DataType(out1);
}
#else
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
dst[i_dst0] = DataType(out0);
dst[i_dst1] = DataType(out1);
}
#endif
fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
let y = (f32(i / 2) - low) / max(0.001f, high - low);
@ -184,7 +115,7 @@ fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
var mscale = params.attn_factor;
var theta = params.freq_scale * theta_extrap;
var theta = params.freq_scale * theta_extrap;
if (params.ext_factor != 0.0f) {
let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
@ -211,10 +142,9 @@ fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
}
}
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
// two elements per thread
// two elements per n_threads
if (gid.x >= params.n_threads) {
return;
}
@ -290,6 +220,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x0 = f32(src0[i_src]);
let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
}
#end(SHADER)
}

View File

@ -1,215 +1,12 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "soft_max_f32",
"DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_inplace",
"DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_sink",
"DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_sink_inplace",
"DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_inplace",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_inplace",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_sink",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_sink",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(BASE_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS)
#decl(BASE_BINDINGS_INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS_INPLACE)
#decl(SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS)
#decl(SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS_INPLACE)
#decl(MASK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS)
#decl(MASK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS_INPLACE)
#decl(MASK_SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS)
#decl(MASK_SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS_INPLACE)
#decl(NOT_INPLACE)
fn inter_value(i: u32) -> f32 {
return dst[i];
}
fn update(i: u32, val: f32) {
dst[i] = val;
}
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn inter_value(i: u32) -> f32 {
return src[i];
}
fn update(i: u32, val: f32) {
src[i] = val;
}
#enddecl(INPLACE)
#decl(NO_MASK)
fn mask_val(i: u32) -> f32 {
return 0.0;
}
#enddecl(NO_MASK)
#decl(MASK)
fn mask_val(i: u32) -> f32 {
return f32(mask[i]);
}
#enddecl(MASK)
#decl(NO_SINK)
fn lower_max_bound(i2: u32) -> f32 {
return -1e30;
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val;
}
#enddecl(NO_SINK)
#decl(SINK)
fn lower_max_bound(i2: u32) -> f32 {
return sinks[params.offset_sinks + i2];
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#enddecl(SINK)
#end(DECLS)
#define(SHADER)
enable f16;
#ifdef MASK_F32
#define MaskType f32
#endif
#ifdef MASK_F16
#define MaskType f16
#endif
struct Params {
offset_src0: u32,
offset_src1: u32,
@ -249,14 +46,117 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
#ifdef HAS_MASK
#ifdef HAS_SINK
@group(0) @binding(1)
var<storage, read_write> mask: array<MaskType>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
#ifdef INPLACE
@group(0) @binding(3)
var<uniform> params: Params;
#else
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(4)
var<uniform> params: Params;
#endif
#else
@group(0) @binding(1)
var<storage, read_write> mask: array<MaskType>;
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif
#else
#ifdef HAS_SINK
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#else
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#endif
#endif
#endif
#ifdef INPLACE
fn inter_value(i: u32) -> f32 {
return src[i];
}
fn update(i: u32, val: f32) {
src[i] = val;
}
#else
fn inter_value(i: u32) -> f32 {
return dst[i];
}
fn update(i: u32, val: f32) {
dst[i] = val;
}
#endif
#ifdef HAS_MASK
fn mask_val(i: u32) -> f32 {
return f32(mask[i]);
}
#else
fn mask_val(i: u32) -> f32 {
return 0.0;
}
#endif
#ifdef HAS_SINK
fn lower_max_bound(i2: u32) -> f32 {
return sinks[params.offset_sinks + i2];
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#else
fn lower_max_bound(i2: u32) -> f32 {
return -1e30;
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val;
}
#endif
const CACHE_SIZE: u32 = 16;
var<workgroup> scratch: array<f32, WG_SIZE>;
override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
@ -268,7 +168,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
let head = f32(i2);
let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
@ -286,12 +186,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
if (col < CACHE_SIZE) {
cache[col] = val;
}
col += wg_size;
col += WG_SIZE;
}
scratch[lid.x] = max_val;
workgroupBarrier();
var offset = wg_size / 2;
var offset: u32 = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
@ -317,12 +217,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
} else {
update(i_dst_row + col, ex);
}
col += wg_size;
col += WG_SIZE;
}
scratch[lid.x] = sum;
workgroupBarrier();
offset = wg_size / 2;
offset = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
@ -339,7 +239,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
break;
}
update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
col += wg_size;
col += WG_SIZE;
}
}
#end(SHADER)

View File

@ -380,22 +380,33 @@ extern "C" {
size_t n_samplers;
};
struct llama_model_tensor_override {
const char * pattern;
enum ggml_type type;
};
struct llama_model_imatrix_data {
const char * name;
const float * data;
size_t size;
};
// model quantization parameters
typedef struct llama_model_quantize_params {
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
enum ggml_type output_tensor_type; // output tensor type
enum ggml_type token_embedding_type; // token embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
bool dry_run; // calculate and show the final quantization size without performing quantization
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types
void * prune_layers; // pointer to vector containing layer indices to prune
int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
enum llama_ftype ftype; // quantize to this llama_ftype
enum ggml_type output_tensor_type; // output tensor type
enum ggml_type token_embedding_type; // token embeddings tensor type
bool allow_requantize; // allow quantizing non-f32/f16 tensors
bool quantize_output_tensor; // quantize output.weight
bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
bool pure; // quantize all tensors to the default type
bool keep_split; // quantize to the same number of shards
bool dry_run; // calculate and show the final quantization size without performing quantization
const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data
const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides
const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides
const int32_t * prune_layers; // pointer to layer indices to prune
} llama_model_quantize_params;
typedef struct llama_logit_bias {

View File

@ -0,0 +1,45 @@
{{- bos_token -}}
{%- set keep_past_thinking = keep_past_thinking | default(false) -%}
{%- set ns = namespace(system_prompt="") -%}
{%- if messages[0]["role"] == "system" -%}
{%- set ns.system_prompt = messages[0]["content"] -%}
{%- set messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: [" -%}
{%- for tool in tools -%}
{%- if tool is not string -%}
{%- set tool = tool | tojson -%}
{%- endif -%}
{%- set ns.system_prompt = ns.system_prompt + tool -%}
{%- if not loop.last -%}
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
{%- endif -%}
{%- endfor -%}
{%- set ns.system_prompt = ns.system_prompt + "]" -%}
{%- endif -%}
{%- if ns.system_prompt -%}
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
{%- endif -%}
{%- set ns.last_assistant_index = -1 -%}
{%- for message in messages -%}
{%- if message["role"] == "assistant" -%}
{%- set ns.last_assistant_index = loop.index0 -%}
{%- endif -%}
{%- endfor -%}
{%- for message in messages -%}
{{- "<|im_start|>" + message["role"] + "\n" -}}
{%- set content = message["content"] -%}
{%- if content is not string -%}
{%- set content = content | tojson -%}
{%- endif -%}
{%- if message["role"] == "assistant" and not keep_past_thinking and loop.index0 != ns.last_assistant_index -%}
{%- if "</think>" in content -%}
{%- set content = content.split("</think>")[-1] | trim -%}
{%- endif -%}
{%- endif -%}
{{- content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}

View File

@ -139,7 +139,11 @@ def main():
'_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL18flash_attn_ext_vecILi128ELi2EL9ggml_type2ELS0_2ELb0EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS6_IjLj3EEiiiiiiiiiiiliiliiiiil',
'_ZL9mul_mat_qIL9ggml_type10ELi16ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
'_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type40ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type40ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type40ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
'_ZL9mul_mat_qIL9ggml_type40ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
}
functions = parse_log_file(log_file)

File diff suppressed because it is too large Load Diff

View File

@ -1 +1 @@
a04eea0761a85d18f3f504d6ab970c5c9dce705f
50634c28837c24ac68b380b5750b41e701c87d73

View File

@ -19,7 +19,7 @@
// dedup helpers
static ggml_tensor * build_kq_mask(
static ggml_tensor * build_attn_inp_kq_mask(
ggml_context * ctx,
const llama_kv_cache_context * mctx,
const llama_ubatch & ubatch,
@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask(
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_kq_mask");
return res;
}
static bool can_reuse_kq_mask(
@ -52,6 +56,21 @@ static bool can_reuse_kq_mask(
// impl
static ggml_tensor * ggml_mul_mat_aux(
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * rot) {
const auto n = rot->ne[0];
ggml_tensor * res;
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
res = ggml_mul_mat (ctx, rot, res);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
return res;
}
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
mctx->set_input_v_rot(self_v_rot);
}
}
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
if (self_k_rot) {
mctx->get_base()->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
mctx->get_base()->set_input_v_rot(self_v_rot);
}
}
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
if (inp_attn->self_k_rot) {
mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
}
if (inp_attn->self_v_rot) {
mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
if (inp_attn->self_k_rot) {
attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
}
if (inp_attn->self_v_rot) {
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@ -2002,13 +2053,13 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
return inp;
}
@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn(
int il) const {
GGML_ASSERT(v_mla == nullptr);
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
}
if (inp->self_v_rot) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
}
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
@ -2090,9 +2154,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_mla,
float kq_scale,
int il) const {
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
if (k_cur) {
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
}
}
if (inp->self_v_rot) {
if (v_cur) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
}
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
}
if (wo) {
cur = build_lora_mm(wo, cur);
}
@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
}
{
@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask_swa);
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
}
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask);
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask_swa);
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}

View File

@ -308,6 +308,10 @@ public:
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
// note: assumes v_rot^ == I
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;
// note: these have to be copies because in order to be able to reuse a graph, its inputs
// need to carry these parameters with them. otherwise, they can point to freed
// llm_graph_params from a previous batch, causing stack-use-after-return
@ -384,6 +388,10 @@ public:
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
// note: using same rotation matrices for both base and swa cache
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;
const llama_hparams hparams;
const llama_cparams cparams;

View File

@ -13,6 +13,65 @@
#include <map>
#include <stdexcept>
static bool ggml_is_power_of_2(int n) {
return (n & (n - 1)) == 0;
}
// orthonormal Walsh-Hadamard rotation matrix
// note: res^2 == I
static void ggml_gen_hadamard(ggml_tensor * tensor) {
assert(tensor->type == GGML_TYPE_F32);
const int n = tensor->ne[0];
assert(ggml_is_power_of_2(n));
assert(tensor->ne[1] == n);
assert(tensor->ne[2] == 1);
assert(tensor->ne[3] == 1);
std::vector<float> data_f32;
float * data = (float *) tensor->data;
if (tensor->type != GGML_TYPE_F32) {
data_f32.resize(n*n);
data = data_f32.data();
}
data[0*n + 0] = 1.0 / sqrtf(n);
for (int s = 1; s < n; s *= 2) {
for (int i = 0; i < s; i++) {
for (int j = 0; j < s; j++) {
const float val = data[i*n + j];
data[(i + s)*n + (j )] = val;
data[(i )*n + (j + s)] = val;
data[(i + s)*n + (j + s)] = -val;
}
}
}
if (tensor->type != GGML_TYPE_F32) {
ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr);
}
}
static ggml_tensor * ggml_mul_mat_aux(
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * rot) {
const auto n = rot->ne[0];
ggml_tensor * res;
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
res = ggml_mul_mat (ctx, rot, res);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
return res;
}
//
// llama_kv_cache
//
@ -209,6 +268,48 @@ llama_kv_cache::llama_kv_cache(
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
if (attn_rot_disable) {
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
}
attn_rot_k =
!attn_rot_disable &&
ggml_is_quantized(type_k) &&
!hparams.is_n_embd_k_gqa_variable() &&
hparams.n_embd_head_k() % 64 == 0;
attn_rot_v =
!attn_rot_disable &&
ggml_is_quantized(type_v) &&
!hparams.is_n_embd_v_gqa_variable() &&
hparams.n_embd_head_v() % 64 == 0;
LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
// pre-compute the haramard matrices and keep them in host memory
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
if (attn_rot_k || attn_rot_v) {
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
attn_rot_hadamard[n] = std::vector<float>(n*n);
ggml_init_params params = {
/* .mem_size = */ 1*ggml_tensor_overhead(),
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ true,
};
ggml_context_ptr ctx { ggml_init(params) };
ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n);
tmp->data = attn_rot_hadamard[n].data();
ggml_gen_hadamard(tmp);
}
}
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
}
@ -1004,6 +1105,14 @@ bool llama_kv_cache::get_has_shift() const {
return result;
}
ggml_type llama_kv_cache::type_k() const {
return layers[0].k->type;
}
ggml_type llama_kv_cache::type_v() const {
return layers[0].v->type;
}
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
@ -1189,6 +1298,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
return v_idxs;
}
ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
ggml_tensor * res = nullptr;
if (attn_rot_k) {
int nrot = 64;
// TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V)
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
do {
nrot *= 2;
} while (hparams.n_embd_head_k() % nrot == 0);
nrot /= 2;
res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_k_rot");
}
return res;
}
ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const {
ggml_tensor * res = nullptr;
if (attn_rot_v) {
int nrot = 64;
// using smaller rotation matrices for V seems beneficial
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570
//do {
// nrot *= 2;
//} while (hparams.n_embd_head_v() % nrot == 0);
//nrot /= 2;
res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_v_rot");
}
return res;
}
void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
@ -1507,6 +1657,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
}
}
void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const {
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
const auto n_rot = dst->ne[0];
GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));
memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
}
void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const {
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
const auto n_rot = dst->ne[0];
GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));
memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
}
size_t llama_kv_cache::total_size() const {
size_t size = 0;
@ -1542,6 +1710,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * rot,
ggml_tensor * factors,
float freq_base,
float freq_scale,
@ -1567,10 +1736,16 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
// dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
// rotate back
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
tmp = ggml_rope_ext(ctx, tmp,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
// rotate fwd
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
tmp = ggml_cpy(ctx, tmp, cur);
} else {
// we rotate only the first n_rot dimensions
@ -1591,6 +1766,9 @@ public:
ggml_tensor * k_shift; // I32 [kv_size*n_stream]
// note: assumes k_rot^2 == I
ggml_tensor * k_rot = nullptr;
const llama_kv_cache * kv_self;
};
@ -1600,6 +1778,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
if (k_shift) {
kv_self->set_input_k_shift(k_shift);
}
if (k_rot) {
kv_self->set_input_k_rot(k_rot);
}
}
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
@ -1611,6 +1793,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
ggml_set_input(inp->k_shift);
inp->k_rot = build_input_k_rot(ctx);
const auto & cparams = lctx->get_cparams();
for (const auto & layer : layers) {
@ -1635,7 +1819,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
ggml_row_size(layer.k->type, n_embd_k_gqa),
ggml_row_size(layer.k->type, n_embd_nope));
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il);
ggml_build_forward_expand(gf, cur);
}
@ -2239,6 +2423,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
return n_kv;
}
ggml_type llama_kv_cache_context::type_k() const {
return kv->type_k();
}
ggml_type llama_kv_cache_context::type_v() const {
return kv->type_v();
}
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}
@ -2263,6 +2455,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con
return kv->build_input_v_idxs(ctx, ubatch);
}
ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const {
return kv->build_input_k_rot(ctx);
}
ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const {
return kv->build_input_v_rot(ctx);
}
void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst);
}
@ -2282,3 +2482,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}
void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const {
kv->set_input_k_rot(dst);
}
void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const {
kv->set_input_v_rot(dst);
}

View File

@ -152,6 +152,9 @@ public:
bool get_has_shift() const;
ggml_type type_k() const;
ggml_type type_v() const;
//
// graph_build API
//
@ -191,6 +194,9 @@ public:
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
@ -199,6 +205,9 @@ public:
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_k_rot(ggml_tensor * dst) const;
void set_input_v_rot(ggml_tensor * dst) const;
private:
const llama_model & model;
const llama_hparams & hparams;
@ -226,6 +235,13 @@ private:
// SWA
const uint32_t n_swa = 0;
// env: LLAMA_ATTN_ROT_DISABLE
bool attn_rot_k = false;
bool attn_rot_v = false;
// pre-computed hadamard martrices
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
@ -262,6 +278,7 @@ private:
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * rot,
ggml_tensor * factors,
float freq_base,
float freq_scale,
@ -328,6 +345,9 @@ public:
uint32_t get_n_kv() const;
ggml_type type_k() const;
ggml_type type_v() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@ -347,6 +367,9 @@ public:
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@ -354,6 +377,9 @@ public:
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_k_rot(ggml_tensor * dst) const;
void set_input_v_rot(ggml_tensor * dst) const;
private:
llama_memory_status status;

View File

@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
const bool unified = (mem_attn->get_base()->get_n_stream() == 1);
ubatch = balloc.split_equal(n_ubatch, !unified);
}
if (ubatch.n_tokens == 0) {

View File

@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
// Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice)
const bool unified = (mem_attn->get_n_stream() == 1);
ubatch = balloc.split_equal(n_ubatch, !unified);
}
if (ubatch.n_tokens == 0) {

View File

@ -84,7 +84,6 @@ static std::string remap_imatrix(const std::string & orig_name, const std::map<i
for (const auto & p : mapped) {
if (p.second == blk) {
LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
}
}
@ -188,10 +187,9 @@ struct quantize_state_impl {
model(model), params(params)
{
// compile regex patterns once - they are expensive
if (params->tensor_types) {
const auto & tensor_types = *static_cast<const std::vector<tensor_type_option> *>(params->tensor_types);
for (const auto & [tname, qtype] : tensor_types) {
tensor_type_patterns.emplace_back(std::regex(tname), qtype);
if (params->tt_overrides) {
for (const auto * p = params->tt_overrides; p->pattern != nullptr; p++) {
tensor_type_patterns.emplace_back(std::regex(p->pattern), p->type);
}
}
}
@ -857,12 +855,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
constexpr bool use_mmap = false;
#endif
llama_model_kv_override * kv_overrides = nullptr;
if (params->kv_overrides) {
auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data();
}
const llama_model_kv_override * kv_overrides = params->kv_overrides;
std::vector<std::string> splits = {};
llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr,
fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
@ -879,9 +872,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (params->only_copy) {
ftype = ml.ftype;
}
std::unordered_map<std::string, std::vector<float>> i_data;
const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
if (params->imatrix) {
imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
for (const llama_model_imatrix_data * p = params->imatrix; p->name != nullptr; p++) {
i_data.emplace(p->name, std::vector<float>(p->data, p->data + p->size));
}
imatrix_data = & i_data;
if (imatrix_data) {
LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n",
__func__, (int)imatrix_data->size());
@ -902,7 +899,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
std::vector<int> prune_list = {};
if (params->prune_layers) {
prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
for (const int32_t * p = params->prune_layers; * p != -1; p++) {
prune_list.push_back(* p);
}
}
// copy the KV pairs from the input file
@ -916,20 +915,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
if (params->kv_overrides) {
const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides;
for (const auto & o : overrides) {
if (o.key[0] == 0) break;
if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
for (const llama_model_kv_override * o = params->kv_overrides; o->key[0] != 0; ++o) {
if (o->tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
gguf_set_val_f32(ctx_out.get(), o->key, o->val_f64);
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
// Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context
gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64));
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool);
} else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
gguf_set_val_str(ctx_out.get(), o.key, o.val_str);
gguf_set_val_u32(ctx_out.get(), o->key, (uint32_t)std::abs(o->val_i64));
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
gguf_set_val_bool(ctx_out.get(), o->key, o->val_bool);
} else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
gguf_set_val_str(ctx_out.get(), o->key, o->val_str);
} else {
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o->key);
}
}
}

View File

@ -2712,6 +2712,67 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.run();
}
// LFM2.5 tests - uses plain "List of tools: [...]" and bare [name(args)] without wrapper tokens
{
auto tst = peg_tester("models/templates/LFM2.5-Instruct.jinja", detailed_debug);
// Basic content only
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Single tool call without reasoning
tst.test("[special_function(arg1=1)]")
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with string argument
tst.test("[get_time(city=\"XYZCITY\")]")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
// Tool call with reasoning (enable_thinking=true)
tst.test("<think>I'm\nthinking</think>[special_function(arg1=1)]")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.run();
// Multiple tool calls (parallel)
tst.test("[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]")
.parallel_tool_calls(true)
.tools({
special_function_tool, special_function_tool_with_optional_param
})
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// Tool call with content before tool call
tst.test("Let me check the time.[get_time(city=\"Paris\")]")
.tools({ get_time_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
"", "Let me check the time.", { { "get_time", "{\"city\":\"Paris\"}" } }
))
.run();
// Partial tool call (streaming)
tst.test("[special_function(arg1=")
.tools({ special_function_tool })
.is_partial(true)
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
.run();
// Tool call with empty arguments
tst.test("[empty_args()]")
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
}
// Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format
// Format: <|tools_prefix|>[{"function_name": {...arguments...}}]<|tools_suffix|>
{

View File

@ -365,6 +365,10 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
if (!set_process_priority(params.cpuparams.priority)) {
LOG_WRN("%s: failed to set process priority\n", __func__);
}
// TODO: avoid using atexit() here by making `console` a singleton
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });

View File

@ -13,13 +13,10 @@
#include <unordered_map>
#include <map>
#include <fstream>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <filesystem>
// result of parsing --tensor-type option
// (changes to this struct must be reflected in src/llama-quant.cpp)
// changes to this struct must also be reflected in src/llama-quant.cpp
struct tensor_type_option {
std::string name;
ggml_type type = GGML_TYPE_COUNT;
@ -491,7 +488,6 @@ static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers
int main(int argc, char ** argv) {
std::setlocale(LC_NUMERIC, "C");
if (argc < 3) {
usage(argv[0]);
}
@ -584,8 +580,16 @@ int main(int argc, char ** argv) {
std::vector<std::string> imatrix_datasets;
std::unordered_map<std::string, std::vector<float>> imatrix_data;
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data);
std::vector<llama_model_imatrix_data> i_data;
std::vector<llama_model_tensor_override> t_override;
if (!imatrix_data.empty()) {
params.imatrix = &imatrix_data;
i_data.reserve(imatrix_data.size() + 1);
for (const auto & kv : imatrix_data) {
i_data.push_back({kv.first.c_str(), kv.second.data(), kv.second.size()});
}
i_data.push_back({nullptr, nullptr, 0}); // array terminator
params.imatrix = i_data.data();
{
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE);
@ -603,7 +607,6 @@ int main(int argc, char ** argv) {
kvo.val_str[127] = '\0';
kv_overrides.emplace_back(std::move(kvo));
}
{
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES);
@ -611,7 +614,6 @@ int main(int argc, char ** argv) {
kvo.val_i64 = imatrix_data.size();
kv_overrides.emplace_back(std::move(kvo));
}
if (m_last_call > 0) {
llama_model_kv_override kvo;
std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS);
@ -623,13 +625,19 @@ int main(int argc, char ** argv) {
if (!kv_overrides.empty()) {
kv_overrides.emplace_back();
kv_overrides.back().key[0] = 0;
params.kv_overrides = &kv_overrides;
params.kv_overrides = kv_overrides.data();
}
if (!tensor_type_opts.empty()) {
params.tensor_types = &tensor_type_opts;
t_override.reserve(tensor_type_opts.size() + 1);
for (const auto & tt : tensor_type_opts) {
t_override.push_back({tt.name.c_str(), tt.type});
}
t_override.push_back({nullptr, GGML_TYPE_COUNT}); // array terminator
params.tt_overrides = t_override.data();
}
if (!prune_layers.empty()) {
params.prune_layers = &prune_layers;
prune_layers.push_back(-1); // array terminator
params.prune_layers = prune_layers.data();
}
llama_backend_init();

View File

@ -108,6 +108,10 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
if (!set_process_priority(params.cpuparams.priority)) {
LOG_WRN("%s: failed to set process priority\n", __func__);
}
LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());