Merge branch 'webgpu/conv2d'

This commit is contained in:
Constannnnnt 2026-04-15 13:43:40 -04:00
commit 23cc93fc1f
42 changed files with 1977 additions and 410 deletions

View File

@ -93,4 +93,5 @@ jobs:
export GGML_VK_DISABLE_F16=1
export GGML_VK_DISABLE_COOPMAT=1
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4800
# test-backend-ops is too slow on llvmpipe, skip it
ctest -L main -E test-backend-ops --verbose --timeout 900

View File

@ -198,10 +198,19 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
args_field = format.function_field + "." + args_field;
}
auto tools_parser = p.standard_json_tools(
format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls,
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
auto tools_parser = p.eps();
if (format.section_start.empty() && !format.per_call_start.empty()) {
auto single_tool_parser = p.standard_json_tools(
format.per_call_start, format.per_call_end, inputs.tools, inputs.parallel_tool_calls,
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
tools_parser = p.trigger_rule("tool-calls", p.one_or_more(single_tool_parser + p.space()));
} else {
tools_parser = p.standard_json_tools(
format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls,
inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped,
format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order);
}
// Handle content wrappers if present
if (ctx.content && ctx.content->is_always_wrapped()) {

View File

@ -308,19 +308,23 @@ struct analyze_tools : analyze_base {
private:
// Extract tool calling 'haystack' for further analysis and delegate further analysis based on format
void analyze_tool_calls(const analyze_reasoning & reasoning);
void analyze_tool_calls(const analyze_reasoning & reasoning, bool supports_parallel_tool_calls);
// Analyze format based on position of function and argument name in needle
void analyze_tool_call_format(const std::string & haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle,
const analyze_reasoning & reasoning);
const analyze_reasoning & reasoning,
bool supports_parallel_tool_calls);
// Analyze specifics of JSON native format (entire tool call is a JSON object)
void analyze_tool_call_format_json_native(const std::string & clean_haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle);
// Check if parallel calls in JSON native format array wrapped or tag wrapped
void analyze_json_native_parallel_calls();
// Analyze specifics of non-JSON native format (tags for function name or for function name and arguments)
void analyze_tool_call_format_non_json(const std::string & clean_haystack,
const std::string & fun_name_needle);

View File

@ -558,7 +558,7 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
: analyze_base(tmpl) {
LOG_DBG(ANSI_ORANGE "Phase 3: Tool call analysis\n" ANSI_RESET);
analyze_tool_calls(reasoning);
analyze_tool_calls(reasoning, caps.supports_parallel_tool_calls);
if (format.mode != tool_format::NONE && format.mode != tool_format::JSON_NATIVE) {
if (caps.supports_parallel_tool_calls) {
@ -577,7 +577,7 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
}
}
void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) {
void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning, bool supports_parallel_tool_calls) {
json assistant_no_tools = json{
{ "role", "assistant" },
{ "content", ASSISTANT_MSG }
@ -611,13 +611,14 @@ void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) {
return;
}
analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning);
analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning, supports_parallel_tool_calls);
}
void analyze_tools::analyze_tool_call_format(const std::string & haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle,
const analyze_reasoning & reasoning) {
const analyze_reasoning & reasoning,
bool supports_parallel_tool_calls) {
if (fun_name_needle.empty() || arg_name_needle.empty() || haystack.empty()) {
return;
}
@ -660,6 +661,9 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
if (format.mode == tool_format::JSON_NATIVE) {
analyze_tool_call_format_json_native(clean_haystack, fun_name_needle, arg_name_needle);
if (supports_parallel_tool_calls) {
analyze_json_native_parallel_calls();
}
} else {
analyze_tool_call_format_non_json(clean_haystack, fun_name_needle);
}
@ -668,6 +672,42 @@ void analyze_tools::analyze_tool_call_format(const std::string & haystack,
format.per_call_end = trim_whitespace(format.per_call_end);
}
void analyze_tools::analyze_json_native_parallel_calls() {
json assistant_one_tool = json{
{ "role", "assistant" },
{ "content", "" },
{ "tool_calls", json::array({ first_tool_call }) }
};
json assistant_two_tools = json{
{ "role", "assistant" },
{ "content", "" },
{ "tool_calls", json::array({ first_tool_call, second_tool_call }) }
};
template_params params;
params.messages = json::array({ user_msg, assistant_one_tool });
params.tools = tools;
params.add_generation_prompt = false;
params.enable_thinking = true;
auto comparison = compare_variants(
*tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_two_tools }); });
if (!comparison) {
LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__);
return;
}
std::string & second_call = comparison->diff.right;
if (!format.section_start.empty() && second_call.find(format.section_start) != std::string::npos) {
format.per_call_start = format.section_start;
format.per_call_end = format.section_end;
format.section_start.clear();
format.section_end.clear();
}
}
void analyze_tools::analyze_tool_call_format_json_native(const std::string & clean_haystack,
const std::string & fun_name_needle,
const std::string & arg_name_needle) {

View File

@ -676,7 +676,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys(
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() +
literal("\"") + tool_name(literal(name)) + literal("\"");
atomic(literal("\"") + tool_name(literal(name)) + literal("\""));
auto nested_args = literal("\"" + nested_args_field + "\"") + space() + literal(":") + space() +
tool_args(schema(json(), "tool-" + name + "-schema", params));
@ -744,7 +744,7 @@ common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys(
ordered_json params = function.contains("parameters") ? function.at("parameters") : ordered_json::object();
auto tool_name_ = name_key_parser + space() + literal(":") + space() +
literal("\"") + tool_name(literal(name)) + literal("\"");
atomic(literal("\"") + tool_name(literal(name)) + literal("\""));
auto tool_args_ = args_key_parser + space() + literal(":") + space() +
tool_args(schema(json(), "tool-" + name + "-schema", params));

View File

@ -281,6 +281,12 @@ Use `GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F` environment variable to force use FP16
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`.
### Peer Access
The environment variable `GGML_CUDA_P2P` can be set to enable peer-to-peer access between multiple GPUs, allowing them to transfer data directly rather than to go through system memory.
Requires driver support (usually restricted to workstation/datacenter GPUs).
May cause crashes or corrupted outputs for some motherboards and BIOS settings (e.g. IOMMU).
### Performance Tuning
The following compilation options are also available to tweak performance:

View File

@ -130,6 +130,23 @@ Note:
- Adding a model-specific API or CLI is an anti-pattern in `libmtmd`. The goal of `libmtmd` is to provide an easy-to-use, model-agnostic library for multimodal pipeline.
- In most cases, `llama-mtmd-cli` should not be modified. If a model requires a specific prompt, either let the user provide it or bake it into the Jinja chat template.
## Tips and tricks
### Working with ggml_rope_ext
PyTorch implementations usually prefer explicitly calculating `freq_cis`/`sin`/`cos` components. However, in llama.cpp, most RoPE operations can be handled via `ggml_rope_ext`, which does not require a sin/cos matrix. This saves memory while allowing the GGML RoPE kernel to be fused with other ops.
However, since `ggml_rope_ext` only provides a subset of the RoPE implementations that models use, converting models from PyTorch to llama.cpp may require some creative adaptations.
For more information about `ggml_rope_ext`, please refer to the in-code documentation in `ggml.h`.
Examples:
- `libmtmd` implements 2D RoPE with `GGML_ROPE_TYPE_NORMAL` ordering by splitting the input tensor in half, applying `ggml_rope_ext` separately to each half, then joining them back together using `ggml_concat`.
- The [Kimi-K2.5](https://github.com/ggml-org/llama.cpp/pull/19170) vision encoder uses vision RoPE with interleaved frequencies. The weights must be permuted during conversion in order to reuse the `build_rope_2d()` function.
- [Gemma 4](https://github.com/ggml-org/llama.cpp/pull/21309) uses "proportional" RoPE. We employ a trick where `rope_freqs` is set to a very large value in the last dimensions to prevent those dimensions from being rotated. See the `Gemma4Model` class in `convert_hf_to_gguf.py`.
- Some models require scaling the input position. For example, `[0, 1, 2, ...]` becomes `[0, 0.5, 1, ...]`. In this case, you can provide the scaling via `freq_scale = 0.5f`.
- Some models use learned RoPE frequencies instead of relying on `powf(freq_base, -2.0 * i / n_dims)`. In this case, you can provide the learned frequencies via the `rope_freqs` tensor (corresponding to the `c` argument in `ggml_rope_ext`), then set `freq_base = 1.0f`. An important note is that `rope_freqs` in GGML is the **inverse** (`theta = pos[i] / rope_freqs`), so you may need to invert `rope_freqs` during conversion.
## GGUF specification
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md

View File

@ -602,8 +602,8 @@ int main(int argc, char ** argv) {
int n_input = input_tokens.size();
if (n_input >= params.n_ctx) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
if (static_cast<uint32_t>(n_input) >= llama_n_ctx(ctx)) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, llama_n_ctx(ctx));
llama_free(ctx);
llama_model_free(model);
return 1;

View File

@ -202,8 +202,11 @@ extern "C" {
// Common functions that may be obtained using ggml_backend_reg_get_proc_address
// AllReduce operation for tensor parallelism (meta backend)
typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
// Context management and operations for faster communication between backends, used for tensor parallelism (meta backend)
typedef void * (*ggml_backend_comm_init_t)(ggml_backend_t * backends, size_t n_backends);
typedef void (*ggml_backend_comm_free_t)(void * comm_ctx);
typedef bool (*ggml_backend_comm_allreduce_tensor_t)(void * comm_ctx, struct ggml_tensor ** tensors);
// Split buffer type for tensor parallelism (old)
typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
// Set the number of threads for the backend

View File

@ -6,9 +6,9 @@
extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 6
#define RPC_PROTO_PATCH_VERSION 1
#define RPC_PROTO_MAJOR_VERSION 4
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0
#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");

View File

@ -1773,8 +1773,32 @@ extern "C" {
int n_dims,
int mode);
// custom RoPE
// RoPE operations with extended options
// a is the input tensor to apply RoPE to, shape [n_embd, n_head, n_token]
// b is an int32 vector with size n_token
// c is freq factors (e.g. phi3-128k), (optional)
// mode can be GGML_ROPE_TYPE_NORMAL or NEOX; for MROPE and VISION mode, use ggml_rope_multi
//
// pseudo-code for computing theta:
// for i in [0, n_dims/2):
// theta[i] = b[i] * powf(freq_base, -2.0 * i / n_dims);
// theta[i] = theta[i] / c[i]; # if c is provided, divide theta by c
// theta[i] = rope_yarn(theta[i], ...); # note: theta = theta * freq_scale is applied here
//
// other params are used by YaRN RoPE scaling, these default values will disable YaRN:
// freq_scale = 1.0f
// ext_factor = 0.0f
// attn_factor = 1.0f
// beta_fast = 0.0f
// beta_slow = 0.0f
//
// example:
// (marking: c = cos, s = sin, 0 = unrotated)
// given a single head with size = 8 --> [00000000]
// GGML_ROPE_TYPE_NORMAL n_dims = 4 --> [cscs0000]
// GGML_ROPE_TYPE_NORMAL n_dims = 8 --> [cscscscs]
// GGML_ROPE_TYPE_NEOX n_dims = 4 --> [ccss0000]
// GGML_ROPE_TYPE_NEOX n_dims = 8 --> [ccccssss]
GGML_API struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -1790,6 +1814,36 @@ extern "C" {
float beta_fast,
float beta_slow);
// multi-dimensional RoPE, for Qwen-VL and similar vision models
// mode can be either VISION, MROPE, IMROPE, cannot be combined with NORMAL or NEOX
// sections specify how many dimensions to rotate in each section:
// section length is equivalent to number of cos/sin pairs, NOT the number of dims
// (i.e. sum of 4 sections are expected to be n_dims/2)
// last sections can be 0, means ignored
// all other options are identical to ggml_rope_ext
//
// important note:
// - NEOX ordering is automatically applied and cannot be disabled for MROPE and VISION
// if you need normal ordering, there are 2 methods:
// (1) split the tensor manually using ggml_view
// (2) permute the weight upon conversion
// - for VISION, n_dims must be head_size/2
//
// example M-RoPE:
// given sections = [t=4, y=2, x=2, 0]
// given a single head with size = 18 --> [000000000000000000]
// GGML_ROPE_TYPE_MROPE n_dims = 16 --> [ttttyyxxttttyyxx00] (cos/sin are applied in NEOX ordering)
// GGML_ROPE_TYPE_IMROPE n_dims = 16 --> [ttyxttyxttyxttyx00] (interleaved M-RoPE, still NEOX ordering)
// note: the theta for each dim is computed the same way as ggml_rope_ext, no matter the section
// in other words, idx used for theta: [0123456789... until n_dims/2], not reset for each section
//
// example vision RoPE:
// given sections = [y=4, x=4, 0, 0] (last 2 sections are ignored)
// given a single head with size = 8 --> [00000000]
// GGML_ROPE_TYPE_VISION n_dims = 4 --> [yyyyxxxx]
// other values of n_dims are untested and is undefined behavior
// note: unlike MROPE, the theta for each dim is computed differently for each section
// in other words, idx used for theta: [0123] for y section, then [0123] for x section
GGML_API struct ggml_tensor * ggml_rope_multi(
struct ggml_context * ctx,
struct ggml_tensor * a,

View File

@ -1419,22 +1419,48 @@ struct ggml_backend_meta_context {
size_t max_tmp_size = 0;
size_t max_subgraphs = 0;
void * comm_ctx = nullptr;
ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr;
ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
name = "Meta(";
std::vector<ggml_backend_t> simple_backends;
backend_configs.reserve(n_devs);
simple_backends.reserve(n_devs);
for (size_t i = 0; i < n_devs; i++) {
ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i);
if (i > 0) {
name += ",";
}
name += ggml_backend_dev_name(simple_dev);
backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params));
simple_backends.push_back(ggml_backend_dev_init(simple_dev, params));
backend_configs.emplace_back(simple_backends.back());
}
name += ")";
if (n_devs > 1) {
ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address(
ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init");
if (comm_init != nullptr) {
comm_ctx = comm_init(simple_backends.data(), simple_backends.size());
}
}
if (comm_ctx != nullptr) {
comm_allreduce = (ggml_backend_comm_allreduce_tensor_t)
ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg(
ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor");
GGML_ASSERT(comm_allreduce != nullptr);
}
}
~ggml_backend_meta_context() {
if (comm_ctx != nullptr) {
ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address(
ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free");
GGML_ASSERT(comm_free != nullptr);
comm_free(comm_ctx);
}
for (auto & bc : backend_configs) {
ggml_backend_free(bc.backend);
}
@ -1845,20 +1871,15 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
if (n_backends > 1 && i < n_subgraphs - 1) {
bool backend_allreduce_success = false;
ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address(
ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor");
if (allreduce_tensor) {
std::vector<ggml_backend_t> backends;
backends.reserve(n_backends);
if (backend_ctx->comm_ctx) {
std::vector<ggml_tensor *> nodes;
nodes.reserve(n_backends);
for (size_t j = 0; j < n_backends; j++) {
auto & bcj = backend_ctx->backend_configs[j];
backends.push_back(bcj.backend);
ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
}
backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends);
backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data());
}
if (!backend_allreduce_success) {

View File

@ -924,6 +924,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_F16> {
static constexpr int qr = 1;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q1_0> {
static constexpr int qk = QK1_0;
static constexpr int qr = QR1_0;
static constexpr int qi = QI1_0;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
static constexpr int qk = QK4_0;
@ -1092,10 +1099,6 @@ struct ggml_cuda_device_info {
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
#ifdef GGML_USE_NCCL
ncclComm_t comms[GGML_CUDA_MAX_DEVICES];
#endif // GGML_USE_NCCL
};
const ggml_cuda_device_info & ggml_cuda_info();

View File

@ -711,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
@ -767,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
@ -822,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
@ -843,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_cuda<float, nv_bfloat16>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
@ -864,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_F16:
return convert_unary_cuda<half, float>;
case GGML_TYPE_Q1_0:
return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:

View File

@ -1,5 +1,27 @@
#include "common.cuh"
static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q1_0 * x = (const block_q1_0 *) vx;
const float d = x[ib].d;
const int bit_index_0 = iqs;
const int bit_index_1 = iqs + 1;
const int byte_index_0 = bit_index_0 / 8;
const int bit_offset_0 = bit_index_0 % 8;
const int byte_index_1 = bit_index_1 / 8;
const int bit_offset_1 = bit_index_1 % 8;
// Extract bits: 1 = +d, 0 = -d (branchless)
const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1;
const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1;
v.x = (2*bit_0 - 1) * d;
v.y = (2*bit_1 - 1) * d;
}
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;

View File

@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q1_0:
get_rows_cuda_q<QK1_0, QR1_0, dequantize_q1_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q4_0:
get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);

View File

@ -324,28 +324,22 @@ static ggml_cuda_device_info ggml_cuda_init() {
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
for (int id = 0; id < info.device_count; ++id) {
ggml_cuda_set_device(id);
for (int id_other = 0; id_other < info.device_count; ++id_other) {
if (id == id_other) {
continue;
}
int can_access_peer;
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
if (can_access_peer) {
CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
if (getenv("GGML_CUDA_P2P") != nullptr) {
for (int id = 0; id < info.device_count; ++id) {
ggml_cuda_set_device(id);
for (int id_other = 0; id_other < info.device_count; ++id_other) {
if (id == id_other) {
continue;
}
int can_access_peer;
CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
if (can_access_peer) {
CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0));
}
}
}
}
#ifdef GGML_USE_NCCL
int dev_ids[GGML_CUDA_MAX_DEVICES];
for (int id = 0; id < info.device_count; ++id) {
dev_ids[id] = id;
}
NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids));
#endif // GGML_USE_NCCL
return info;
}
@ -1125,66 +1119,51 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte
/* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
};
bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) {
#ifdef GGML_USE_NCCL
const int64_t ne = ggml_nelements(tensors[0]);
// FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
// This then causes a crash in this function
if (ne == 0) {
return true;
}
for (size_t i = 0; i < n_backends; ++i) {
GGML_ASSERT(tensors[i] != nullptr);
GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
}
struct ggml_backend_cuda_comm_context {
std::vector<ggml_backend_t> backends;
std::vector<ncclComm_t> comms;
const ggml_cuda_device_info info = ggml_cuda_info();
// For small tensors, simply reduce them as FP32.
// The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
~ggml_backend_cuda_comm_context() {
for (ncclComm_t comm : comms) {
NCCL_CHECK(ncclCommDestroy(comm));
}
NCCL_CHECK(ncclGroupEnd());
return true;
}
};
#endif // GGML_USE_NCCL
// For large tensors it's faster to compress them to BF16 for the reduction:
to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32);
to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
static void ggml_backend_cuda_comm_free(void * comm_ctx_v) {
#ifdef GGML_USE_NCCL
if (comm_ctx_v == nullptr) {
return;
}
ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v;
delete comm_ctx;
#else
GGML_UNUSED(comm_ctx_v);
#endif // GGML_USE_NCCL
}
ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
for (size_t i = 0; i < n_backends; ++i) {
static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) {
#ifdef GGML_USE_NCCL
for (size_t i = 0; i < n_backends; i++) {
if (!ggml_backend_is_cuda(backends[i])) {
return nullptr;
}
}
ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context;
std::vector<int> dev_ids;
ret->backends.reserve(n_backends);
dev_ids.reserve(n_backends);
for (size_t i = 0; i < n_backends; i++) {
ret->backends.push_back(backends[i]);
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
tmp[i].pool = &cuda_ctx->pool();
tmp[i].alloc(ne);
ggml_cuda_set_device(i);
to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
dev_ids.push_back(cuda_ctx->device);
}
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
ggml_cuda_set_device(i);
to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
}
return true;
ret->comms.resize(n_backends);
NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data()));
return ret;
#else
// If NCCL is installed it is used by default for optimal performance.
// However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package.
@ -1197,7 +1176,76 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t
warning_printed = true;
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
GGML_UNUSED_VARS(backends, tensors, n_backends);
GGML_UNUSED_VARS(backends, n_backends);
return nullptr;
#endif // GGML_USE_NCCL
}
static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) {
#ifdef GGML_USE_NCCL
const int64_t ne = ggml_nelements(tensors[0]);
// FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0
// This then causes a crash in this function
if (ne == 0) {
return true;
}
GGML_ASSERT(comm_ctx_v != nullptr);
ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v;
const size_t n_backends = comm_ctx->backends.size();
for (size_t i = 0; i < n_backends; ++i) {
GGML_ASSERT(tensors[i] != nullptr);
GGML_ASSERT(ggml_nelements(tensors[i]) == ne);
GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
}
// For small tensors, simply reduce them as FP32.
// The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
return true;
}
// For large tensors it's faster to compress them to BF16 for the reduction:
to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32);
to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES];
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
tmp[i].pool = &cuda_ctx->pool();
tmp[i].alloc(ne);
ggml_cuda_set_device(cuda_ctx->device);
to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
}
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context;
ggml_cuda_set_device(cuda_ctx->device);
to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream());
CUDA_CHECK(cudaGetLastError());
}
return true;
#else
GGML_UNUSED_VARS(comm_ctx_v, tensors);
return false;
#endif // GGML_USE_NCCL
}
@ -4783,6 +4831,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
switch (a->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@ -4820,6 +4869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_F32:
case GGML_TYPE_BF16:
case GGML_TYPE_I32:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@ -5220,8 +5270,14 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
GGML_UNUSED(reg);
if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) {
return (void *)ggml_backend_cuda_allreduce_tensor;
if (strcmp(name, "ggml_backend_comm_init") == 0) {
return (void *)ggml_backend_cuda_comm_init;
}
if (strcmp(name, "ggml_backend_comm_free") == 0) {
return (void *)ggml_backend_cuda_comm_free;
}
if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) {
return (void *)ggml_backend_cuda_comm_allreduce_tensor;
}
if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
return (void *)ggml_backend_cuda_split_buffer_type;

View File

@ -5,6 +5,9 @@
static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
switch (args.type_x) {
case GGML_TYPE_Q1_0:
mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream);
break;
case GGML_TYPE_Q4_0:
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
break;
@ -270,6 +273,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
bool mmq_supported;
switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:

View File

@ -57,6 +57,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b
static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
switch (type_x) {
case GGML_TYPE_Q1_0:
return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return MMQ_Q8_1_DS_LAYOUT_DS4;
@ -185,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() {
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
switch (type) {
case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() {
// ------------------------------------------------------------
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
#else
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
int * x_qs = (int *) x_tile;
float * x_df = (float *) (x_qs + txs.qs);
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0;
constexpr int threads_per_row = blocks_per_iter * QI1_0;
constexpr int nrows = warp_size / threads_per_row;
constexpr int scale_entries_per_block = QK1_0 / QK8_1;
constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block;
const int txi = threadIdx.x % threads_per_row;
const int kbx = txi / QI1_0;
const int kqsx = txi % QI1_0;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
if (need_check) {
i = min(i, i_max);
}
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx;
const int qs_offset = 4*kqsx;
const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) |
(bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24);
int unpacked_bytes[8];
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int shift = j * 4;
const int bits4 = (qs0 >> shift) & 0x0F;
const int b0 = (bits4 & 0x01) ? 1 : -1;
const int b1 = (bits4 & 0x02) ? 1 : -1;
const int b2 = (bits4 & 0x04) ? 1 : -1;
const int b3 = (bits4 & 0x08) ? 1 : -1;
unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
}
const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0;
#pragma unroll
for (int j = 0; j < 8; ++j) {
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j];
#else
x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j];
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
const int ksx = threadIdx.x % scale_entries_per_row;
const int scale_block = ksx / scale_entries_per_block;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + threadIdx.y;
if (need_check) {
i = min(i, i_max);
}
const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block;
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d;
#else
x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
}
}
template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
constexpr int nwarps = mmq_get_nwarps_device();
@ -3290,6 +3375,14 @@ static __device__ __forceinline__ void mmq_write_back_mma(
template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
struct mmq_type_traits;
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> {
static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>;
static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;

View File

@ -9,6 +9,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1;
case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
@ -36,6 +37,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
@ -886,6 +888,12 @@ static void mul_mat_vec_q_switch_type(
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const int ids_stride, cudaStream_t stream) {
switch (type_x) {
case GGML_TYPE_Q1_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q1_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,

View File

@ -32,6 +32,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
TYPES_MMQ = [
"GGML_TYPE_Q1_0",
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmq.cuh"
DECL_MMQ_CASE(GGML_TYPE_Q1_0);

View File

@ -106,6 +106,9 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism
#define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block
#define VDR_Q4_0_Q8_1_MMVQ 2
#define VDR_Q4_0_Q8_1_MMQ 4
@ -669,6 +672,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
return d6 * sumf_d;
}
static __device__ __forceinline__ float vec_dot_q1_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx;
// Q1_0: 128 elements with ONE scale
// Q8_1: 32 elements per block with individual scales
// iqs selects which of the 4 chunks of 32 elements to process (0-3)
const float d1 = bq1_0->d;
// Process only the chunk specified by iqs
const block_q8_1 * bq8_1_chunk = bq8_1 + iqs;
// Load 32 bits (4 bytes) for this chunk from Q1_0
const int offset = iqs * 4;
const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) |
(bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24);
// Unpack 32 bits into 32 signed values (-1 or +1)
int vi_bytes[8];
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int shift = j * 4;
const int bits4 = (v >> shift) & 0x0F;
const int b0 = (bits4 & 0x01) ? 1 : -1;
const int b1 = (bits4 & 0x02) ? 1 : -1;
const int b2 = (bits4 & 0x04) ? 1 : -1;
const int b3 = (bits4 & 0x08) ? 1 : -1;
vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
}
// Compute dot product for this 32-element chunk
int sumi = 0;
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int u = get_int_b4(bq8_1_chunk->qs, j);
sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi);
}
// Apply Q1_0's single scale and this chunk's Q8_1 scale
const float d8 = __low2float(bq8_1_chunk->ds);
return d1 * d8 * sumi;
}
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

View File

@ -7,3 +7,26 @@ ggml_add_backend_library(ggml-rpc
if (WIN32)
target_link_libraries(ggml-rpc PRIVATE ws2_32)
endif()
# RDMA auto-detection (Linux only, requires libibverbs)
if (NOT WIN32 AND NOT APPLE)
find_library(IBVERBS_LIB ibverbs)
if (IBVERBS_LIB)
option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" ON)
else()
option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" OFF)
endif()
else()
set(GGML_RPC_RDMA OFF CACHE BOOL "RDMA not available on this platform" FORCE)
endif()
if (GGML_RPC_RDMA)
if (NOT IBVERBS_LIB)
find_library(IBVERBS_LIB ibverbs REQUIRED)
endif()
target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA)
target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB})
message(STATUS " RDMA transport enabled (auto-detected)")
else()
message(STATUS " RDMA transport disabled")
endif()

View File

@ -3,7 +3,9 @@
#include "ggml-backend-impl.h"
#include "ggml-cpp.h"
#include <array>
#include <cinttypes>
#include <optional>
#include <string>
#include <vector>
#include <memory>
@ -31,6 +33,14 @@
#include <filesystem>
#include <algorithm>
#ifdef GGML_RPC_RDMA
# include <infiniband/verbs.h>
# include <time.h>
# ifndef _WIN32
# include <poll.h>
# endif
#endif // GGML_RPC_RDMA
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
#define LOG_DBG(...) \
@ -49,17 +59,116 @@ typedef int sockfd_t;
#endif
// cross-platform socket
#ifdef GGML_RPC_RDMA
static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock)
static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB
static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes
using rdma_gid_t = std::array<uint8_t, RDMA_GID_SIZE>;
struct rdma_conn {
struct ibv_context * ctx = nullptr;
struct ibv_pd * pd = nullptr;
struct ibv_cq * scq = nullptr; // send completions
struct ibv_cq * rcq = nullptr; // recv completions
struct ibv_qp * qp = nullptr;
void * tx_buf = nullptr;
struct ibv_mr * tx_mr = nullptr;
void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous
struct ibv_mr * rx_mr = nullptr;
int rx_head = 0;
uint32_t max_inline = 0;
uint8_t * rx_slot(int i) const {
return static_cast<uint8_t *>(rx_buf) + static_cast<size_t>(i) * RDMA_CHUNK;
}
bool post_rx(int i) {
struct ibv_sge sge = {};
sge.addr = (uintptr_t)rx_slot(i);
sge.length = RDMA_CHUNK;
sge.lkey = rx_mr->lkey;
struct ibv_recv_wr wr = {}, * bad = nullptr;
wr.wr_id = (uint64_t)i;
wr.sg_list = &sge;
wr.num_sge = 1;
return ibv_post_recv(qp, &wr, &bad) == 0;
}
~rdma_conn() {
if (tx_mr) ibv_dereg_mr(tx_mr);
if (rx_mr) ibv_dereg_mr(rx_mr);
free(tx_buf);
free(rx_buf);
if (qp) ibv_destroy_qp(qp);
if (scq) ibv_destroy_cq(scq);
if (rcq) ibv_destroy_cq(rcq);
if (pd) ibv_dealloc_pd(pd);
if (ctx) ibv_close_device(ctx);
}
};
// Local RDMA parameters captured during the probe phase and later consumed
// by rdma_activate() after the remote side's caps arrive via HELLO.
struct rdma_local_info {
uint32_t qpn = 0;
uint32_t psn = 0;
uint8_t gid[RDMA_GID_SIZE] = {};
uint8_t ib_port = 0;
int gid_idx = 0;
enum ibv_mtu path_mtu = IBV_MTU_1024;
};
#endif // GGML_RPC_RDMA
// conn_caps size for transport-agnostic capability exchange
static constexpr size_t RPC_CONN_CAPS_SIZE = 24;
// conn_caps RDMA layout helper
#ifdef GGML_RPC_RDMA
struct rdma_caps {
uint32_t qpn;
uint32_t psn;
uint8_t gid[RDMA_GID_SIZE];
};
static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size");
#endif // GGML_RPC_RDMA
// Forward declarations for transport function pointers
struct socket_t;
static bool tcp_send_impl(socket_t * sock, const void * data, size_t size);
static bool tcp_recv_impl(socket_t * sock, void * data, size_t size);
struct socket_t {
sockfd_t fd;
bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl;
bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl;
#ifdef GGML_RPC_RDMA
std::unique_ptr<rdma_conn> rdma;
rdma_local_info rdma_local = {};
#endif // GGML_RPC_RDMA
socket_t(sockfd_t fd) : fd(fd) {}
~socket_t() {
#ifdef GGML_RPC_RDMA
rdma.reset();
#endif // GGML_RPC_RDMA
LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
closesocket(this->fd);
if (fd != INVALID_SOCKET) closesocket(this->fd);
#else
close(this->fd);
if (fd >= 0) close(this->fd);
#endif
}
// Advertise local transport capabilities into conn_caps.
// May probe RDMA and store the probe on this socket for update_caps.
void get_caps(uint8_t * caps);
// Activate transport upgrade based on remote conn_caps using the probe
// previously stored by get_caps.
void update_caps(const uint8_t * remote_caps);
};
// macro for nicer error messages on server crash
@ -115,10 +224,16 @@ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
struct rpc_msg_hello_req {
uint8_t conn_caps[RPC_CONN_CAPS_SIZE];
};
struct rpc_msg_hello_rsp {
uint8_t major;
uint8_t minor;
uint8_t patch;
uint8_t padding;
uint8_t conn_caps[RPC_CONN_CAPS_SIZE];
};
struct rpc_msg_device_count_rsp {
@ -414,27 +529,414 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
return true;
}
static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
return false;
}
return send_data(sockfd, msg, msg_size);
// TCP transport implementations (for function-pointer dispatch)
static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) {
return send_data(sock->fd, data, size);
}
static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) {
return recv_data(sock->fd, data, size);
}
// RDMA transport (performance-optimized, auto-negotiated)
#ifdef GGML_RPC_RDMA
static bool rdma_send_impl(socket_t * sock, const void * data, size_t size);
static bool rdma_recv_impl(socket_t * sock, void * data, size_t size);
static inline bool tcp_peer_closed(int fd) {
if (fd < 0) return false;
#ifndef _WIN32
struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 };
int r = poll(&pfd, 1, 0);
return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP));
#else
return false;
#endif
}
static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) {
for (uint64_t s = 0; ; s++) {
int n = ibv_poll_cq(cq, 1, wc);
if (n > 0) {
if (wc->status != IBV_WC_SUCCESS) {
GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n",
wc->status, ibv_wc_status_str(wc->status), wc->vendor_err);
}
return wc->status == IBV_WC_SUCCESS;
}
if (n < 0) return false;
if ((s & 0xFFFFF) == 0 && s > 0) {
if (tcp_peer_closed(tcp_fd)) {
return false;
}
}
}
}
static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) {
const uint8_t * src = (const uint8_t *)data;
size_t rem = size;
while (rem > 0) {
size_t chunk = std::min(rem, RDMA_CHUNK);
struct ibv_sge sge = {};
struct ibv_send_wr wr = {}, * bad = nullptr;
wr.opcode = IBV_WR_SEND;
wr.sg_list = &sge;
wr.num_sge = 1;
if (chunk <= c->max_inline) {
sge.addr = (uintptr_t)src;
sge.length = chunk;
wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE;
} else {
memcpy(c->tx_buf, src, chunk);
sge.addr = (uintptr_t)c->tx_buf;
sge.length = chunk;
sge.lkey = c->tx_mr->lkey;
wr.send_flags = IBV_SEND_SIGNALED;
}
if (ibv_post_send(c->qp, &wr, &bad) != 0) return false;
struct ibv_wc wc;
if (!rdma_poll(c->scq, &wc, tcp_fd)) return false;
src += chunk;
rem -= chunk;
}
return true;
}
static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) {
uint8_t * dst = (uint8_t *)data;
size_t rem = size;
while (rem > 0) {
struct ibv_wc wc;
if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false;
int slot = (int)wc.wr_id;
size_t got = wc.byte_len;
memcpy(dst, c->rx_slot(slot), got);
if (!c->post_rx(slot)) return false;
dst += got;
rem -= got;
}
return true;
}
static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) {
return rdma_send(sock->rdma.get(), data, size, sock->fd);
}
static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) {
return rdma_recv(sock->rdma.get(), data, size, sock->fd);
}
// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address.
// Used to match the socket's local IP against the kernel's GID table so that
// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly:
// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4)
// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape)
// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is
// Returns std::nullopt on unsupported family or getsockname failure.
static std::optional<rdma_gid_t> rdma_build_target_gid(sockfd_t tcp_fd) {
sockaddr_storage addr = {};
socklen_t addr_len = sizeof(addr);
if (getsockname(tcp_fd, reinterpret_cast<sockaddr *>(&addr), &addr_len) != 0) {
return std::nullopt;
}
rdma_gid_t target = {};
if (addr.ss_family == AF_INET) {
const auto * a = reinterpret_cast<const sockaddr_in *>(&addr);
target[10] = 0xff;
target[11] = 0xff;
memcpy(&target[12], &a->sin_addr, 4);
return target;
}
if (addr.ss_family == AF_INET6) {
const auto * a = reinterpret_cast<const sockaddr_in6 *>(&addr);
memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE);
return target;
}
return std::nullopt;
}
static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) {
const char * dev_env = std::getenv("GGML_RDMA_DEV");
const char * gid_env = std::getenv("GGML_RDMA_GID");
auto target_gid = rdma_build_target_gid(tcp_fd);
if (!target_gid) {
return nullptr;
}
const uint8_t ib_port = 1;
int num_devs = 0;
ibv_device ** devs = ibv_get_device_list(&num_devs);
if (!devs || num_devs == 0) return nullptr;
ibv_context * ibctx = nullptr;
const char * matched_dev = nullptr;
int gid_idx = gid_env ? atoi(gid_env) : -1;
int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB
for (int d = 0; d < num_devs; d++) {
const char * dn = ibv_get_device_name(devs[d]);
if (dev_env && strcmp(dev_env, dn) != 0) continue;
ibv_context * ctx = ibv_open_device(devs[d]);
if (!ctx) continue;
ibv_port_attr pa;
if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; }
int found_gid = gid_idx;
int found_version = IBV_GID_TYPE_IB;
if (found_gid < 0) {
// Find a GID on this port whose bytes equal the local TCP address
// (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1
// (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths
// are avoided. ibv_query_gid_ex returns gid+type in one call.
int v2_idx = -1;
int v1_idx = -1;
for (int i = 0; i < pa.gid_tbl_len; i++) {
ibv_gid_entry entry = {};
if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue;
if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue;
if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) {
v2_idx = i;
} else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) {
v1_idx = i;
}
}
if (v2_idx >= 0) {
found_gid = v2_idx;
found_version = IBV_GID_TYPE_ROCE_V2;
} else if (v1_idx >= 0) {
found_gid = v1_idx;
found_version = IBV_GID_TYPE_ROCE_V1;
}
} else {
// Explicit GID index from GGML_RDMA_GID — fetch its type for logging.
ibv_gid_entry entry = {};
if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) {
found_version = entry.gid_type;
}
}
if (found_gid >= 0) {
ibctx = ctx;
gid_idx = found_gid;
gid_version = found_version;
matched_dev = dn;
out->path_mtu = pa.active_mtu;
break;
}
ibv_close_device(ctx);
}
ibv_free_device_list(devs);
if (!ibctx) return nullptr;
out->ib_port = ib_port;
out->gid_idx = gid_idx;
// unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(),
// so each failure path is a plain `return nullptr;`.
auto c = std::make_unique<rdma_conn>();
c->ctx = ibctx;
c->pd = ibv_alloc_pd(ibctx);
if (!c->pd) return nullptr;
c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0);
c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0);
if (!c->scq || !c->rcq) return nullptr;
ibv_qp_init_attr qia = {};
qia.send_cq = c->scq;
qia.recv_cq = c->rcq;
qia.qp_type = IBV_QPT_RC;
qia.cap.max_send_wr = 4;
qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4;
qia.cap.max_send_sge = 1;
qia.cap.max_recv_sge = 1;
qia.cap.max_inline_data = 256;
c->qp = ibv_create_qp(c->pd, &qia);
if (!c->qp) return nullptr;
c->max_inline = qia.cap.max_inline_data;
c->tx_buf = aligned_alloc(4096, RDMA_CHUNK);
c->rx_buf = aligned_alloc(4096, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK);
if (!c->tx_buf || !c->rx_buf) return nullptr;
c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE);
c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
if (!c->tx_mr || !c->rx_mr) return nullptr;
ibv_gid local_gid;
if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr;
out->qpn = c->qp->qp_num;
out->psn = c->qp->qp_num & 0xffffff;
memcpy(out->gid, &local_gid, RDMA_GID_SIZE);
const char * ver_str = "";
if (gid_version == IBV_GID_TYPE_ROCE_V2) {
ver_str = " RoCEv2";
} else if (gid_version == IBV_GID_TYPE_ROCE_V1) {
ver_str = " RoCEv1";
}
GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n",
matched_dev, gid_idx, ver_str, out->qpn, c->max_inline);
return c.release();
}
// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS.
// On success, the connection is live and ready for rdma_send/rdma_recv.
static bool rdma_activate(rdma_conn * c, const rdma_local_info * local,
uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) {
// RESET -> INIT
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_INIT;
a.port_num = local->ib_port;
a.pkey_index = 0;
a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
return false;
}
}
for (int i = 0; i < RDMA_RX_DEPTH; i++) {
if (!c->post_rx(i)) return false;
}
// INIT -> RTR
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_RTR;
a.path_mtu = local->path_mtu;
a.dest_qp_num = remote_qpn;
a.rq_psn = remote_psn;
a.max_dest_rd_atomic = 1;
a.min_rnr_timer = 1;
a.ah_attr.is_global = 1;
memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE);
a.ah_attr.grh.hop_limit = 1;
a.ah_attr.grh.sgid_index = local->gid_idx;
a.ah_attr.dlid = 0;
a.ah_attr.port_num = local->ib_port;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) {
return false;
}
}
// RTR -> RTS
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_RTS;
a.timeout = 14;
a.retry_cnt = 7;
a.rnr_retry = 7;
a.sq_psn = local->psn;
a.max_rd_atomic = 1;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY |
IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) {
return false;
}
}
GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n",
local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH);
return true;
}
#endif // GGML_RPC_RDMA
// ---------------------------------------------------------------------------
// socket_t transport capability methods
// ---------------------------------------------------------------------------
void socket_t::get_caps(uint8_t * caps) {
memset(caps, 0, RPC_CONN_CAPS_SIZE);
#ifdef GGML_RPC_RDMA
rdma_local = {};
rdma.reset(rdma_probe(fd, &rdma_local));
if (rdma) {
rdma_caps rc = {};
rc.qpn = rdma_local.qpn;
rc.psn = rdma_local.psn;
memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE);
memcpy(caps, &rc, sizeof(rc));
}
#endif // GGML_RPC_RDMA
}
void socket_t::update_caps(const uint8_t * remote_caps) {
#ifdef GGML_RPC_RDMA
if (!rdma) {
return;
}
rdma_caps rc = {};
memcpy(&rc, remote_caps, sizeof(rc));
if (rc.qpn == 0) {
rdma.reset();
return;
}
if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) {
fn_send = rdma_send_impl;
fn_recv = rdma_recv_impl;
} else {
GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n");
rdma.reset();
}
#else
(void)remote_caps;
#endif // GGML_RPC_RDMA
}
// unified transport dispatch (via function pointers)
static bool send_data(socket_t * sock, const void * data, size_t size) {
return sock->fn_send(sock, data, size);
}
static bool recv_data(socket_t * sock, void * data, size_t size) {
return sock->fn_recv(sock, data, size);
}
static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) {
if (!send_data(sock, &msg_size, sizeof(msg_size))) {
return false;
}
return send_data(sock, msg, msg_size);
}
static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) {
uint64_t size;
if (!recv_data(sockfd, &size, sizeof(size))) {
if (!recv_data(sock, &size, sizeof(size))) {
return false;
}
if (size != msg_size) {
return false;
}
return recv_data(sockfd, msg, msg_size);
return recv_data(sock, msg, msg_size);
}
static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
static bool recv_msg(socket_t * sock, std::vector<uint8_t> & input) {
uint64_t size;
if (!recv_data(sockfd, &size, sizeof(size))) {
if (!recv_data(sock, &size, sizeof(size))) {
return false;
}
try {
@ -443,7 +945,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
return false;
}
return recv_data(sockfd, input.data(), size);
return recv_data(sock, input.data(), size);
}
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
@ -452,7 +954,11 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
return false;
}
host = endpoint.substr(0, pos);
port = std::stoi(endpoint.substr(pos + 1));
try {
port = std::stoi(endpoint.substr(pos + 1));
} catch (...) {
return false;
}
return true;
}
@ -460,13 +966,13 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
// No response
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
uint8_t cmd_byte = cmd;
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) {
return false;
}
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
if (!send_data(sock.get(), &input_size, sizeof(input_size))) {
return false;
}
if (!send_data(sock->fd, input, input_size)) {
if (!send_data(sock.get(), input, input_size)) {
return false;
}
return true;
@ -478,16 +984,14 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
if (!send_rpc_cmd(sock, cmd, input, input_size)) {
return false;
}
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
// even if we do, we can skip sending output_size from the server for commands with known output size
uint64_t out_size;
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
if (!recv_data(sock.get(), &out_size, sizeof(out_size))) {
return false;
}
if (out_size != output_size) {
return false;
}
if (!recv_data(sock->fd, output, output_size)) {
if (!recv_data(sock.get(), output, output_size)) {
return false;
}
return true;
@ -495,17 +999,25 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
// RPC client-side implementation
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
rpc_msg_hello_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
// Performs HELLO handshake with transport auto-negotiation.
// Advertises local capabilities via conn_caps; if the server responds with
// matching capabilities, the socket is upgraded transparently.
static bool negotiate_hello(const std::shared_ptr<socket_t> & sock) {
rpc_msg_hello_req request = {};
rpc_msg_hello_rsp response = {};
sock->get_caps(request.conn_caps);
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n",
response.major, response.minor, response.patch);
return false;
}
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
}
sock->update_caps(response.conn_caps);
return true;
}
@ -527,6 +1039,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
return nullptr;
}
#ifdef _WIN32
if (!initialized) {
WSADATA wsaData;
@ -543,10 +1056,10 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
if (sock == nullptr) {
return nullptr;
}
if (!check_server_version(sock)) {
if (!negotiate_hello(sock)) {
return nullptr;
}
LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str());
sockets[endpoint] = sock;
return sock;
}
@ -1597,25 +2110,46 @@ rpc_server::~rpc_server() {
}
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
sockfd_t sockfd) {
socket_t * sockfd) {
rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
return;
}
// the first command sent by the client must be HELLO
if (cmd != RPC_CMD_HELLO) {
GGML_LOG_ERROR("Expected HELLO command, update client\n");
return;
}
if (!recv_msg(sockfd, nullptr, 0)) {
// Read input_size and validate protocol version
uint64_t hello_input_size;
if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) {
return;
}
rpc_msg_hello_rsp response;
server.hello(response);
if (!send_msg(sockfd, &response, sizeof(response))) {
if (hello_input_size != sizeof(rpc_msg_hello_req)) {
GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu) — client needs upgrade to protocol v%d.x\n",
(size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION);
return;
}
rpc_msg_hello_req req = {};
if (!recv_data(sockfd, &req, sizeof(req))) {
return;
}
rpc_msg_hello_rsp rsp = {};
server.hello(rsp);
// Advertise server transport capabilities based on client's caps
sockfd->get_caps(rsp.conn_caps);
if (!send_msg(sockfd, &rsp, sizeof(rsp))) {
return;
}
// Activate transport upgrade using client's caps
sockfd->update_caps(req.conn_caps);
while (true) {
if (!recv_data(sockfd, &cmd, 1)) {
break;
@ -1884,6 +2418,12 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
if (!parse_endpoint(endpoint, host, port)) {
return;
}
#ifdef GGML_RPC_RDMA
printf(" transport : TCP (RDMA auto-negotiate enabled)\n");
#else
printf(" transport : TCP\n");
#endif // GGML_RPC_RDMA
#ifdef _WIN32
{
WSADATA wsaData;
@ -1907,7 +2447,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
printf("Accepted client connection\n");
fflush(stdout);
rpc_serve_client(backends, cache_dir, client_socket->fd);
rpc_serve_client(backends, cache_dir, client_socket.get());
printf("Client connection closed\n");
fflush(stdout);
}

View File

@ -1394,7 +1394,7 @@ struct vk_op_im2col_push_constants {
uint32_t IW; uint32_t IH;
uint32_t OW; uint32_t OH;
uint32_t KW; uint32_t KH;
uint32_t pelements;
uint32_t OH_batch;
uint32_t CHW;
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
@ -10064,7 +10064,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC };
const uint32_t CHW = IC * KH * KW;
// Cap X workgroups to limit concurrent IC channel reads.
// The shader loops over X to cover the full CHW dimension.
// AMD prefers a lower limit
const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u;
const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW));
elements = { x_elements, OW, OH * batch };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
@ -11727,7 +11733,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH;
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
@ -11739,7 +11744,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
pelements,
OH * batch,
IC * KH * KW,
s0, s1, p0, p1, d0, d1, batch * IC
});

View File

@ -13,7 +13,7 @@ layout (push_constant) uniform parameter
uint IW; uint IH;
uint OW; uint OH;
uint KW; uint KH;
uint pelements;
uint OH_batch;
uint CHW;
int s0; int s1;
int p0; int p1;
@ -34,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void im2col(const uint y, const uint z) {
const uint gidx = gl_GlobalInvocationID.x;
void im2col(const uint ow, const uint z_idx) {
const uint oh = z_idx % p.OH;
const uint batch_idx = z_idx / p.OH;
const uint oh = y;
const uint batch = z / p.IC;
const uint ic = z % p.IC;
const uint gidx = gl_LocalInvocationID.x;
const uint src_batch = batch_idx * p.batch_offset;
const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * p.KH;
const uint KHKW = p.KH * p.KW;
const uint base_linear_idx = gidx * NUM_ITER;
uint wg_x = gl_WorkGroupID.x;
do {
const uint wg_offset = wg_x * 512;
uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;
[[unroll]] for (uint i = 0; i < NUM_ITER; ++i) {
const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE;
A_TYPE values[NUM_ITER];
BDA_OFFSET_T offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
}
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == p.KH) {
current_ky = 0;
current_kx++;
if (chw_idx >= p.CHW) {
return;
}
}
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint ic = chw_idx / KHKW;
const uint rem = chw_idx - ic * KHKW;
const uint ky = rem / p.KW;
const uint kx = rem - ky * p.KW;
const uint linear_idx = base_linear_idx + idx;
const uint iiw = ow * p.s0 + kx * p.d0 - p.p0;
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
if (linear_idx >= p.pelements) {
continue;
}
A_TYPE val = A_TYPE(0);
if (iih < p.IH && iiw < p.IW) {
val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw];
}
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
dst_addr.d = D_TYPE(values[idx]);
D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx));
out_ptr.d = D_TYPE(val);
#else
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
data_d[dst_row + chw_idx] = D_TYPE(val);
#endif
}
}
wg_x += gl_NumWorkGroups.x;
} while (wg_x * 512 < p.CHW);
}
void main() {
uint y = gl_GlobalInvocationID.y;
while (y < p.OH) {
uint ow = gl_GlobalInvocationID.y;
while (ow < p.OW) {
uint z = gl_GlobalInvocationID.z;
while (z < p.batch_IC) {
im2col(y, z);
while (z < p.OH_batch) {
im2col(ow, z);
z += gl_NumWorkGroups.z;
}
y += gl_NumWorkGroups.y;
ow += gl_NumWorkGroups.y;
}
}

View File

@ -248,6 +248,27 @@ struct ggml_webgpu_ssm_conv_pipeline_key {
}
};
/** CONV 2D */
struct ggml_webgpu_conv2d_pipeline_key {
ggml_type weight_type;
ggml_type input_type;
ggml_type output_type;
bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const {
return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type;
}
};
struct ggml_webgpu_conv2d_pipeline_key_hash {
size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.weight_type);
ggml_webgpu_hash_combine(seed, key.input_type);
ggml_webgpu_hash_combine(seed, key.output_type);
return seed;
}
};
/** Gated Delta Net **/
struct ggml_webgpu_gated_delta_net_pipeline_key {
int type;
@ -831,6 +852,8 @@ class ggml_webgpu_shader_lib {
rope_pipelines;
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
soft_max_pipelines;
std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash>
conv2d_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@ -1115,8 +1138,7 @@ class ggml_webgpu_shader_lib {
std::string type_upper = type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (key.src_type)
{
switch (key.src_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@ -1136,9 +1158,9 @@ class ggml_webgpu_shader_lib {
break;
}
default:
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
{
defines.push_back(std::string("SRC_TYPE=") + type_str);
}
}
defines.push_back("BYTE_HELPERS");
@ -1621,8 +1643,7 @@ class ggml_webgpu_shader_lib {
std::string type_upper = src0_name;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
switch (context.src0->type)
{
switch (context.src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
@ -1642,9 +1663,9 @@ class ggml_webgpu_shader_lib {
break;
}
default:
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
{
defines.push_back(std::string("SRC0_TYPE=") + src0_name);
}
}
defines.push_back("BYTE_HELPERS");
@ -2340,6 +2361,47 @@ class ggml_webgpu_shader_lib {
return soft_max_pipelines[key];
}
webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_conv2d_pipeline_key key = {
.weight_type = context.src0->type,
.input_type = context.src1->type,
.output_type = context.dst->type,
};
auto it = conv2d_pipelines.find(key);
if (it != conv2d_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "conv_2d";
auto push_type_defines = [&](const char * prefix, ggml_type type) {
std::string s_prefix = prefix;
if (type == GGML_TYPE_F32) {
defines.push_back(s_prefix + "_F32");
} else if (type == GGML_TYPE_F16) {
defines.push_back(s_prefix + "_F16");
} else {
GGML_ABORT("Unsupported type for CONV_2D shader");
}
};
push_type_defines("WEIGHT", key.weight_type);
push_type_defines("INPUT", key.input_type);
push_type_defines("OUTPUT", key.output_type);
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_conv2d, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
conv2d_pipelines[key] = pipeline;
return conv2d_pipelines[key];
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,

View File

@ -8,6 +8,7 @@
#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "ggml-webgpu-shader-lib.hpp"
#include "ggml.h"
#ifdef __EMSCRIPTEN__
# include <emscripten/emscripten.h>
@ -83,7 +84,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u
#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6)
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 256 // enough for 64 parameters
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
@ -1014,6 +1015,98 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx,
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx,
wgpu::CommandEncoder & encoder,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
(uint32_t) src1->ne[0],
(uint32_t) src1->ne[1],
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
(uint32_t) s0,
(uint32_t) s1,
(uint32_t) p0,
(uint32_t) p1,
(uint32_t) d0,
(uint32_t) d1,
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
{ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
};
uint32_t max_wg_size =
std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX);
uint32_t wg_size =
std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size);
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = wg_size,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t n_out = ggml_nelements(dst);
uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size);
uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
uint32_t wg_x = std::min(total_wg, max_wg);
uint32_t wg_y = CEIL_DIV(total_wg, wg_x);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx,
wgpu::CommandEncoder & encoder,
ggml_tensor * src0,
@ -2918,6 +3011,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
return ggml_webgpu_sum_rows(ctx, encoder, src0, node);
case GGML_OP_CONV_2D:
return ggml_webgpu_conv_2d(ctx, encoder, src0, src1, node);
default:
return std::nullopt;
}
@ -3884,6 +3979,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_SOLVE_TRI:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
break;
case GGML_OP_CONV_2D:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
break;
case GGML_OP_SSM_CONV:
supports_op = op->type == GGML_TYPE_F32;
break;

View File

@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 {
#endif
#ifdef U32_DEQUANT_HELPERS
fn load_u16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
return (word >> shift) & 0xFFFF;
#ifdef DECLARE_BYTE_LOADERS_SRC
fn load_u16_at_src(byte_offset: u32) -> u32 {
let word = src[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_u32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4;
let shift = (byte_offset & 0x3) * 8;
let lo = buf[word_idx];
let hi = buf[word_idx + 1];
let shifted = (lo >> shift) | (hi << (32 - shift));
return select(shifted, lo, shift == 0);
fn load_u32_at_src(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 0x3u) * 8u;
let lo = src[word_idx];
let hi = src[word_idx + 1u];
let shifted = (lo >> shift) | (hi << (32u - shift));
return select(shifted, lo, shift == 0u);
}
fn load_f16_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at(buf, byte_offset));
fn load_f16_at_src(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at_src(byte_offset));
return f16(packed[0]);
}
fn load_f16_as_f32_at(
buf: ptr<storage, array<u32>, read_write>,
byte_offset: u32) -> f32 {
let word = buf[byte_offset / 4];
let shift = (byte_offset & 0x2) * 8;
let d_bits = (word >> shift) & 0xFFFF;
fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 {
let word = src[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
let d_bits = (word >> shift) & 0xFFFFu;
return unpack2x16float(d_bits)[0];
}
#endif
#ifdef DECLARE_BYTE_LOADERS_SRC0
fn load_u16_at_src0(byte_offset: u32) -> u32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
return (word >> shift) & 0xFFFFu;
}
fn load_u32_at_src0(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 0x3u) * 8u;
let lo = src0[word_idx];
let hi = src0[word_idx + 1u];
let shifted = (lo >> shift) | (hi << (32u - shift));
return select(shifted, lo, shift == 0u);
}
fn load_f16_at_src0(byte_offset: u32) -> f16 {
let packed = unpack2x16float(load_u16_at_src0(byte_offset));
return f16(packed[0]);
}
fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 {
let word = src0[byte_offset / 4u];
let shift = (byte_offset & 0x2u) * 8u;
let d_bits = (word >> shift) & 0xFFFFu;
return unpack2x16float(d_bits)[0];
}
#endif
#endif
#ifdef Q4_1_T

View File

@ -0,0 +1,166 @@
#include "common_decls.tmpl"
enable f16;
@group(0) @binding(0)
#if defined(WEIGHT_F32)
var<storage, read_write> weights: array<f32>;
#elif defined(WEIGHT_F16)
var<storage, read_write> weights: array<f16>;
#endif
@group(0) @binding(1)
#if defined(INPUT_F32)
var<storage, read_write> input: array<f32>;
#elif defined(INPUT_F16)
var<storage, read_write> input: array<f16>;
#endif
@group(0) @binding(2)
#if defined(OUTPUT_F32)
var<storage, read_write> output: array<f32>;
#elif defined(OUTPUT_F16)
var<storage, read_write> output: array<f16>;
#endif
struct Params {
offset_w: u32,
offset_i: u32,
offset_o: u32,
// element strides
sw0: u32, sw1: u32, sw2: u32, sw3: u32,
si0: u32, si1: u32, si2: u32, si3: u32,
so0: u32, so1: u32, so2: u32, so3: u32,
// kernel dimensions
KW: u32, KH: u32, IC: u32,
// input dimensions
IW: u32, IH: u32,
// output dimensions
OW: u32, OH: u32, OC_out: u32, N_out: u32,
// stride
s0: u32, s1: u32,
// padding
p0: u32, p1: u32,
// dilation
d0: u32, d1: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
fn load_weight(idx: u32) -> f32 {
#if defined(WEIGHT_F32)
return weights[idx];
#elif defined(WEIGHT_F16)
return f32(weights[idx]);
#endif
}
fn load_input(idx: u32) -> f32 {
#if defined(INPUT_F32)
return input[idx];
#elif defined(INPUT_F16)
return f32(input[idx]);
#endif
}
fn store_output(idx: u32, val: f32) {
#if defined(OUTPUT_F32)
output[idx] = val;
#elif defined(OUTPUT_F16)
output[idx] = f16(val);
#endif
}
fn ceil_div_u32(x: u32, y: u32) -> u32 {
return (x + y - 1) / y;
}
// returns the first valid kernel index k such that base + k * step >= 0
fn first_valid_k(base: i32, step: u32) -> u32 {
if (base >= 0) {
return 0;
}
return ceil_div_u32(u32(-base), step);
}
// returns the first invalid kernel index k such that base + k * step >= limit so valid k are in [0, end_valid_k)
fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 {
let remaining = i32(limit) - base;
if (remaining <= 0) {
return 0;
}
return min(k_max, ceil_div_u32(u32(remaining), step));
}
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>
) {
let threads_per_group = u32(WG_SIZE);
let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y;
let n_out = params.OW * params.OH * params.OC_out * params.N_out;
var sum: f32 = 0.0;
if (i_out >= n_out) {
return;
}
// Kernel layout: [KW, KH, IC, ..]
// Input layout: [IW, IH, .., ..]
// Output layout: [OW, OH, OC, N]
var i = i_out;
let n = i / (params.OC_out * params.OH * params.OW);
i = i % (params.OC_out * params.OH * params.OW);
let oc = i / (params.OH * params.OW);
i = i % (params.OH * params.OW);
let oh = i / params.OW;
let ow = i % params.OW;
let ow_base = i32(ow * params.s0) - i32(params.p0);
let oh_base = i32(oh * params.s1) - i32(params.p1);
// clip the valid kernel window once
let kw_begin = first_valid_k(ow_base, params.d0);
let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW);
let kh_begin = first_valid_k(oh_base, params.d1);
let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH);
// entire receptive field is out of bounds
if (kw_begin >= kw_end || kh_begin >= kh_end) {
let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;
store_output(out_idx, 0.0);
return;
}
let weight_oc_base = params.offset_w + oc * params.sw3;
let input_n_base = params.offset_i + n * params.si3;
for (var ic: u32 = 0; ic < params.IC; ic += 1) {
let w_base_ic = ic * params.sw2 + weight_oc_base;
let in_base = ic * params.si2 + input_n_base;
for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) {
let ih = u32(oh_base + i32(kh * params.d1));
let w_row_base = w_base_ic + kh * params.sw1;
let in_row_base = in_base + ih * params.si1;
for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) {
let iw = u32(ow_base + i32(kw * params.d0));
let w_idx = w_row_base + kw * params.sw0;
let in_idx = in_row_base + iw * params.si0;
sum += load_weight(w_idx) * load_input(in_idx);
}
}
}
let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3;
store_output(out_idx, sum);
}

View File

@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC
#include "common_decls.tmpl"
#ifdef F32_VEC
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 4; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q5_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let qh_packed = load_u32_at(&src, block_byte_base + 2);
let d = load_f16_as_f32_at_src(block_byte_base);
let qh_packed = load_u32_at_src(block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef Q8_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 8u; j++) {
let q_byte_offset = block_byte_base + 2u + j * 4u;
let q_packed = load_u32_at(&src, q_byte_offset);
let q_packed = load_u32_at_src(q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 108);
let d = load_f16_as_f32_at_src(block_byte_base + 108);
// Bytes 96-107: 12 bytes of scales (3 u32s)
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src, block_byte_base + 104);
scale_vals[0] = load_u32_at_src(block_byte_base + 96);
scale_vals[1] = load_u32_at_src(block_byte_base + 100);
scale_vals[2] = load_u32_at_src(block_byte_base + 104);
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4);
qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4);
}
var dst_i = dst_base + offset * 256;
@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src, block_byte_base + 208);
let d = load_f16_as_f32_at_src(block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4);
ql_vals[i] = load_u32_at_src(block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16u; i++) {
qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u);
qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4);
scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4);
}
var dst_i = dst_base + offset * 256;
@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src, aux0_offset);
let aux1 = load_u32_at(&src, aux1_offset);
let aux0 = load_u32_at_src(aux0_offset);
let aux1 = load_u32_at_src(aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_XS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var scale_vals = array<u32, 2>(
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
load_u32_at_src(block_byte_base + 66),
load_u32_at_src(block_byte_base + 70)
);
for (var ib: u32 = 0; ib < 32; ib += 4) {
@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
);
for (var l: u32 = 0; l < 4; l++) {
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF;
let qs_val = load_u32_at_src(qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ2_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
}
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src, block_byte_base + 70);
qh_vals[0] = load_u32_at_src(block_byte_base + 66);
qh_vals[1] = load_u32_at_src(block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src, block_byte_base + 78);
scale_vals[0] = load_u32_at_src(block_byte_base + 74);
scale_vals[1] = load_u32_at_src(block_byte_base + 78);
for (var ib: u32 = 0; ib < 8; ib ++) {
let s = get_byte(scale_vals[ib / 4], ib % 4);
@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_XXS
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src, sc_sign_offset);
let sc_sign = load_u32_at_src(sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ3_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
var qh_vals = array<u32, 2>(
load_u32_at(&src, block_byte_base + 66),
load_u32_at(&src, block_byte_base + 70)
load_u32_at_src(block_byte_base + 66),
load_u32_at_src(block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4);
sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4);
}
var scale_vals = load_u32_at(&src, block_byte_base + 106);
var scale_vals = load_u32_at_src(block_byte_base + 106);
for (var ib: u32 = 0; ib < 4; ib++) {
let s = get_byte(scale_vals, ib);
@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ1_S
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 256;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF;
let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4);
let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
#ifdef IQ4_NL
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src, block_byte_base);
let d = load_f16_as_f32_at_src(block_byte_base);
var dst_i = dst_base + offset * 32;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4);
qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);

View File

@ -1,7 +1,9 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#ifdef FLOAT
const BLOCK_SIZE = 1u;
@ -21,11 +23,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q4_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d;
@ -63,12 +65,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q5_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
let qh_packed = load_u32_at(&src0, block_byte_base + 2);
let qh_packed = load_u32_at_src0(block_byte_base + 2);
for (var j: u32 = 0; j < 4; j++) {
let q_byte_offset = block_byte_base + 6 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
@ -110,11 +112,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef Q8_0
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var sum: f32 = 0.0;
for (var j: u32 = 0; j < 8; j++) {
let q_byte_offset = block_byte_base + 2 + j * 4;
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -184,7 +186,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
// Bytes 108-109: f16 scale 'd'
let d = load_f16_as_f32_at(&src0, block_byte_base + 108);
let d = load_f16_as_f32_at_src0(block_byte_base + 108);
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
// and 2-bits from the last 4 bytes
@ -192,9 +194,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let kmask1: u32 = 0x03030303;
let kmask2: u32 = 0x0f0f0f0f;
var scale_vals: array<u32, 4>;
scale_vals[0] = load_u32_at(&src0, block_byte_base + 96);
scale_vals[1] = load_u32_at(&src0, block_byte_base + 100);
scale_vals[2] = load_u32_at(&src0, block_byte_base + 104);
scale_vals[0] = load_u32_at_src0(block_byte_base + 96);
scale_vals[1] = load_u32_at_src0(block_byte_base + 100);
scale_vals[2] = load_u32_at_src0(block_byte_base + 104);
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
@ -205,13 +207,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
// Bytes 0-31: 32 bytes of hmask (8 u32s)
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
}
// Bytes 32-95: 64 bytes of qs (16 u32s)
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4);
qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4);
}
var sum = 0.0;
@ -313,24 +315,24 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes
// Bytes 208-209: f16 scale 'd'
let d = load_f16_as_f32_at(&src0, block_byte_base + 208);
let d = load_f16_as_f32_at_src0(block_byte_base + 208);
// Bytes 0-127: 128 bytes of ql (32 u32s)
var ql_vals: array<u32, 32>;
for (var i: u32 = 0; i < 32; i++) {
ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4);
ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4);
}
// Bytes 128-191: 64 bytes of qh (16 u32s)
var qh_vals: array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4);
qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4);
}
// Bytes 192-207: 16 bytes of scales (4 u32s)
var scale_vals: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4);
scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4);
}
var sum = 0.0;
@ -374,14 +376,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 32; ib += 4) {
let aux0_offset = block_byte_base + 2 + ib * 2;
let aux1_offset = block_byte_base + 2 + (ib + 2) * 2;
let aux0 = load_u32_at(&src0, aux0_offset);
let aux1 = load_u32_at(&src0, aux1_offset);
let aux0 = load_u32_at_src0(aux0_offset);
let aux1 = load_u32_at_src0(aux1_offset);
let db = d * (0.5 + f32(aux1 >> 28)) * 0.25;
for (var l: u32 = 0; l < 4; l++) {
let ig = get_byte(aux0, l) * 8;
@ -402,12 +404,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_XS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var scale_vals = array<u32, 2>(
load_u32_at(&src0, block_byte_base + 66),
load_u32_at(&src0, block_byte_base + 70)
load_u32_at_src0(block_byte_base + 66),
load_u32_at_src0(block_byte_base + 70)
);
var sum = 0.0;
@ -419,7 +421,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
);
for (var l: u32 = 0; l < 4; l++) {
let qs_offset = block_byte_base + 2 + (ib + l) * 2;
let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF;
let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF;
let ig = (qs_val & 511) * 8;
let is = qs_val >> 9;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
@ -439,21 +441,21 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ2_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qs_vals : array<u32, 16>;
for (var i: u32 = 0; i < 16; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
}
var qh_vals: array<u32, 2>;
qh_vals[0] = load_u32_at(&src0, block_byte_base + 66);
qh_vals[1] = load_u32_at(&src0, block_byte_base + 70);
qh_vals[0] = load_u32_at_src0(block_byte_base + 66);
qh_vals[1] = load_u32_at_src0(block_byte_base + 70);
var scale_vals: array<u32, 2>;
scale_vals[0] = load_u32_at(&src0, block_byte_base + 74);
scale_vals[1] = load_u32_at(&src0, block_byte_base + 78);
scale_vals[0] = load_u32_at_src0(block_byte_base + 74);
scale_vals[1] = load_u32_at_src0(block_byte_base + 78);
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib ++) {
@ -483,17 +485,17 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ3_XXS
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 16; ib += 2) {
let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2;
let sc_sign = load_u32_at(&src0, sc_sign_offset);
let sc_sign = load_u32_at_src0(sc_sign_offset);
let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5;
for (var l: u32 = 0; l < 4; l++) {
let is = (sc_sign >> (7 * l)) & 127;
let signs = get_byte(ksigns_iq2xs[is / 4], is % 4);
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0);
let ig2 = get_byte(ig_val, 1);
for (var j: u32 = 0; j < 4; j++) {
@ -515,20 +517,20 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ3_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var qh_vals = array<u32, 2>(
load_u32_at(&src0, block_byte_base + 66),
load_u32_at(&src0, block_byte_base + 70)
load_u32_at_src0(block_byte_base + 66),
load_u32_at_src0(block_byte_base + 70)
);
var sign_vals: array<u32, 8>;
for (var i: u32 = 0; i < 8; i++) {
sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4);
sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4);
}
var scale_vals = load_u32_at(&src0, block_byte_base + 106);
var scale_vals = load_u32_at_src0(block_byte_base + 106);
var sum = 0.0;
for (var ib: u32 = 0; ib < 4; ib++) {
@ -543,7 +545,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let sign_w = sign_vals[ib * 2 + k];
for (var l: u32 = 0; l < 4; l++) {
let signs = get_byte(sign_w, l);
let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF;
let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256);
let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256);
for (var j: u32 = 0; j < 4; j++) {
@ -566,14 +568,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ1_S
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 256;
var sum = 0.0;
for (var ib: u32 = 0; ib < 8; ib++) {
let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF;
let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF;
let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0);
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0);
let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4);
let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4);
for (var l: u32 = 0; l < 4; l++) {
let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8;
for (var j: u32 = 0; j < 8; j++) {
@ -638,12 +640,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
#ifdef IQ4_NL
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes
let d = load_f16_as_f32_at(&src0, block_byte_base);
let d = load_f16_as_f32_at_src0(block_byte_base);
var src1_i = src1_idx_base + offset * 32;
var sum = 0.0;
var qs: array<u32, 4>;
for (var i: u32 = 0; i < 4; i++) {
qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4);
qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4);
}
for (var j: u32 = 0; j < 16; j++) {
let qsb = get_byte(qs[j / 4], j % 4);

View File

@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let d = load_f16_at_src0(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_lo = f16(q_byte & 0xF) * d + m;
@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let d = load_f16_at_src0(block_byte_base);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base + 80u);
let dmin = load_f16_at(&src0, block_byte_base + 82u);
let d = load_f16_at_src0(block_byte_base + 80u);
let dmin = load_f16_at_src0(block_byte_base + 82u);
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let is = k_in_block / 16u;
let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u));
let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
let q_idx = q_b_idx + k + l;
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base + 108u);
let d = load_f16_at_src0(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i);
scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i);
hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i);
qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
@ -499,8 +499,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
@ -520,14 +520,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@ -537,7 +537,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
@ -571,8 +571,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at(&src0, block_byte_base);
let dmin = load_f16_at(&src0, block_byte_base + 2u);
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// The original loop processes elements in groups of 64
@ -597,14 +597,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u);
let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u);
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
@ -614,11 +614,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u));
let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u));
let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
let qh_byte = get_byte(qh_packed, l % 4u);
@ -666,17 +666,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat);
let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat);
let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat);
let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
@ -687,10 +687,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx);
let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = load_f16_at(&src0, block_byte_base + 208u);
let d = load_f16_at_src0(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {

View File

@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#ifdef VEC

View File

@ -1,6 +1,8 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
#ifdef VEC

View File

@ -3,7 +3,9 @@ enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#include "mul_mat_decls.tmpl"
// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs.
@ -196,4 +198,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
}
}

View File

@ -1,7 +1,9 @@
enable f16;
#define DECLARE_BYTE_LOADERS_SRC0
#include "common_decls.tmpl"
#ifdef VEC
#define VEC_SIZE 4
@ -65,10 +67,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let d = f32(load_f16_at_src0(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
@ -98,11 +100,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = f32(load_f16_at(&src0, block_byte_base + 2u));
let d = f32(load_f16_at_src0(block_byte_base));
let m = f32(load_f16_at_src0(block_byte_base + 2u));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
@ -132,12 +134,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let qh_packed = load_u32_at(&src0, block_byte_base + 2u);
let d = f32(load_f16_at_src0(block_byte_base));
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -176,13 +178,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = load_f16_at(&src0, block_byte_base + 2u);
let qh_packed = load_u32_at(&src0, block_byte_base + 4u);
let d = f32(load_f16_at_src0(block_byte_base));
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
for (var j = 0u; j < 2; j++) {
let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
let j_adjusted = j + (block_offset / 2u);
@ -221,11 +223,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let d = f32(load_f16_at_src0(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d;
@ -254,12 +256,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES;
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
let d = f32(load_f16_at(&src0, block_byte_base));
let m = load_f16_at(&src0, block_byte_base + 2u);
let d = f32(load_f16_at_src0(block_byte_base));
let m = load_f16_at_src0(block_byte_base + 2u);
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j);
let q_packed = load_u32_at(&src0, q_byte_offset);
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k: u32 = 0; k < 4; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * d + f32(m);
@ -309,13 +311,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
for (var i = ix; i < nb; i += 2u) {
let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at(&src0, bbase + 208u));
let d = f32(load_f16_at_src0(bbase + 208u));
let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l);
let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u);
let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h);
let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte);
let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u);
let ql1_u32 = load_u32_at_src0(bbase + q_offset_l);
let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u);
let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h);
let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte);
let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u);
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);

View File

@ -147,15 +147,12 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-9.010913, 9.010913)));
#endif
#ifdef XIELU
let val = f32(src[params.offset_src + src_idx]);
let res =
select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
src[params.offset_src + src_idx]) *
TYPE(params.alpha_n) +
TYPE(params.beta) * src[params.offset_src + src_idx],
TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
src[params.offset_src + src_idx] +
TYPE(params.beta) * src[params.offset_src + src_idx],
src[params.offset_src + src_idx] > 0.0);
TYPE(select(
((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val,
params.alpha_p * val * val + params.beta * val,
val > 0.0));
#endif
#ifdef SOFTPLUS
let src_f32 = f32(src[params.offset_src + src_idx]);

View File

@ -0,0 +1,161 @@
{%- macro render_content(content, num_img_tokens, num_video_frames) -%}
{%- if content is string -%}
{{- content -}}
{%- elif content is sequence -%}
{%- set ns = namespace(out="", prev_was_text=false) -%}
{%- for item in content -%}
{%- set item_type = item.get("type") -%}
{%- if item_type == "text" or item.get("text") is not none -%}
{%- set text = item.get("text", "") -%}
{%- if text -%}
{%- if ns.prev_was_text -%}
{%- set ns.out = ns.out ~ " " -%}
{%- endif -%}
{%- set ns.out = ns.out ~ text -%}
{%- endif -%}
{%- set ns.prev_was_text = text != "" -%}
{%- elif item_type in ["image", "image_url"] or item.get("image") is not none or item.get("image_url") is not none -%}
{%- set ns.out = ns.out ~ "<image>" ~ ("<REKA_IMG_TOKEN>" * num_img_tokens) ~ "</image>" -%}
{%- set ns.prev_was_text = false -%}
{%- elif item_type in ["video", "video_url"] or item.get("video") is not none or item.get("video_url") is not none -%}
{%- set repeat_tokens = num_img_tokens * num_video_frames -%}
{%- set ns.out = ns.out ~ "<video>" ~ ("<REKA_IMG_TOKEN>" * repeat_tokens) ~ "</video>" -%}
{%- set ns.prev_was_text = false -%}
{%- endif -%}
{%- endfor -%}
{{- ns.out -}}
{%- endif -%}
{%- endmacro -%}
{%- set ns = namespace(out="", last_query_index=messages|length - 1) -%}
{%- for msg in messages[::-1] -%}
{%- set idx = messages|length - 1 - loop.index0 -%}
{%- if msg.get("role") == "user" -%}
{%- set content = msg.get("content", "") -%}
{%- if not (content is string and content.startswith("<tool_response>") and content.endswith("</tool_response>")) -%}
{%- set ns.last_query_index = idx -%}
{%- break -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- set last_query_index = ns.last_query_index -%}
{%- set num_img_tokens = num_img_tokens | default(64, true) | int -%}
{%- set num_video_frames = num_video_frames | default(6, true) | int -%}
{%- set start_idx = 0 -%}
{%- set system_text = "" -%}
{%- if messages|length > 0 and messages[0].get("role") in ["system", "developer"] -%}
{%- set system_text = render_content(messages[0].get("content", ""), num_img_tokens, num_video_frames) -%}
{%- set start_idx = 1 -%}
{%- endif -%}
{%- if tools or system_text -%}
{%- set preamble_ns = namespace(text="") -%}
{%- if system_text -%}
{%- set preamble_ns.text = "system: " ~ system_text -%}
{%- endif -%}
{%- if tools -%}
{%- if preamble_ns.text -%}
{%- set preamble_ns.text = preamble_ns.text ~ "\n\n" -%}
{%- else -%}
{%- set preamble_ns.text = "system: " -%}
{%- endif -%}
{%- set preamble_ns.text = preamble_ns.text
~ "# Tools\n\n"
~ "You may call one or more functions to assist with the user query.\n\n"
~ "You are provided with function signatures within <tools></tools> XML tags:\n"
~ "<tools>" -%}
{%- for tool in tools -%}
{%- set preamble_ns.text = preamble_ns.text ~ "\n" ~ (tool | tojson(ensure_ascii=True)) -%}
{%- endfor -%}
{%- set preamble_ns.text = preamble_ns.text
~ "\n</tools>\n\n"
~ "For each function call, return a json object with function name and arguments "
~ "within <tool_call></tool_call> XML tags:\n"
~ "<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
{%- endif -%}
{%- set ns.out = ns.out ~ preamble_ns.text ~ "\n\n<sep>" -%}
{%- endif -%}
{%- for idx in range(start_idx, messages|length) -%}
{%- set message = messages[idx] -%}
{%- set role = message.get("role") -%}
{%- set content = message.get("content") -%}
{%- if role == "user" -%}
{%- set prefix_ns = namespace(value="human: ") -%}
{%- if content is sequence and content is not string -%}
{%- for item in content -%}
{%- if item.get("type") == "text" or item.get("text") is not none -%}
{%- set text = item.get("text", "") -%}
{%- if text -%}
{%- break -%}
{%- endif -%}
{%- elif item.get("type") in ["image", "image_url", "video", "video_url"] -%}
{%- set prefix_ns.value = "human:" -%}
{%- break -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}
{%- set ns.out = ns.out ~ prefix_ns.value ~ render_content(content, num_img_tokens, num_video_frames) ~ "<sep>" -%}
{%- elif role == "assistant" -%}
{%- set tool_calls = message.get("tool_calls") -%}
{%- set content_text = render_content(content, num_img_tokens, num_video_frames) -%}
{%- set reasoning_text = "" -%}
{%- if message.get("reasoning_content") is string -%}
{%- set reasoning_text = message.get("reasoning_content") -%}
{%- elif "</think>" in content_text -%}
{%- set reasoning_text = content_text.split("</think>", 1)[0].rstrip("\n").split("<think>")[-1].lstrip("\n") -%}
{%- set content_text = content_text.split("</think>", 1)[1].lstrip("\n") -%}
{%- endif -%}
{%- set ns.out = ns.out ~ "assistant: " -%}
{%- set include_thinking = enable_thinking is true
and idx > last_query_index
and (idx == messages|length - 1 or reasoning_text)
-%}
{%- if include_thinking -%}
{%- set ns.out = ns.out ~ "<think>\n" ~ (reasoning_text.strip() ) ~ "\n</think>\n\n" -%}
{%- endif -%}
{%- set ns.out = ns.out ~ content_text -%}
{%- if tool_calls -%}
{%- if content_text and not ns.out.endswith("\n") -%}
{%- set ns.out = ns.out ~ "\n" -%}
{%- endif -%}
{%- for tool_call in tool_calls -%}
{%- if tool_call.get("function") is not none -%}
{%- set tool_call = tool_call.get("function") -%}
{%- endif -%}
{%- set arguments = tool_call.get("arguments", {}) -%}
{%- if arguments is string -%}
{%- set arguments_json = arguments -%}
{%- elif arguments is mapping -%}
{%- set arguments_json = arguments | tojson(ensure_ascii=True) -%}
{%- else -%}
{%- set arguments_json = arguments | tojson(ensure_ascii=True) -%}
{%- endif -%}
{%- set ns.out = ns.out
~ "<tool_call>\n"
~ "{\"name\": \"" ~ tool_call.get("name", "") ~ "\", \"arguments\": "
~ arguments_json
~ "}\n</tool_call>" -%}
{%- endfor -%}
{%- endif -%}
{%- if not (continue_final_message and idx == messages|length - 1) -%}
{%- set ns.out = ns.out ~ "\n\n<sep>" -%}
{%- endif -%}
{%- elif role == "tool" -%}
{%- if idx == start_idx or messages[idx - 1].get("role") != "tool" -%}
{%- set ns.out = ns.out ~ "human: " -%}
{%- endif -%}
{%- set response_text = render_content(content, num_img_tokens, num_video_frames) -%}
{%- set ns.out = ns.out ~ "<tool_response>\n" ~ response_text ~ "\n</tool_response>" -%}
{%- if idx == messages|length - 1 or messages[idx + 1].get("role") != "tool" -%}
{%- set ns.out = ns.out ~ "<sep>" -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt
and (messages|length == 0 or messages[-1].get("role") != "assistant")
-%}
{%- if enable_thinking is true -%}
{%- set ns.out = ns.out ~ "assistant: <think>\n" -%}
{%- else -%}
{%- set ns.out = ns.out ~ "assistant:" -%}
{%- endif -%}
{%- endif -%}
{{- ns.out -}}

View File

@ -2164,7 +2164,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
tst.test(
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}"
"</tool_call>")
.tools({ special_function_tool })
.expect(message_assist_call)
@ -2172,7 +2172,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
tst.test(
"Hello, world!\nWhat's up?<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}"
"</tool_call>")
.tools({ special_function_tool })
.expect(message_assist_call_content)
@ -3329,6 +3329,92 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.run();
}
// Reka-Edge tests - uses native JSON format with per-call wrapper
{
auto tst = peg_tester("models/templates/Reka-Edge.jinja", detailed_debug);
// Basic content only
tst.test("Hello, world!\nWhat's up?").enable_thinking(false).expect(message_assist).run();
// Single tool call without reasoning
tst.test("<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>")
.enable_thinking(false)
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with string argument
tst.test("<tool_call>\n{\"name\": \"get_time\", \"arguments\": {\"city\": \"XYZCITY\"}}</tool_call>")
.enable_thinking(false)
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
// Tool call with reasoning (enable_thinking=true)
tst.test("I'm\nthinking</think><tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.run();
// Multiple tool calls (parallel)
tst.test(
"<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>"
"<tool_call>\n{\"name\": \"special_function_with_opt\", \"arguments\": {\"arg1\": 1, \"arg2\": 2}}</tool_call>"
)
.enable_thinking(false)
.parallel_tool_calls(true)
.tools({
special_function_tool, special_function_tool_with_optional_param
})
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// Tool call with reasoning and content
tst.test("I need to call a function</think>"
"Let me check the time.<tool_call>\n{\"name\": \"get_time\", \"arguments\": {\"city\": \"XYZCITY\"}}</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ get_time_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
"I need to call a function", "Let me check the time.", { { "get_time", "{\"city\":\"XYZCITY\"}" } }
))
.run();
// Partial tool call (streaming)
tst.test("<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\":")
.tools({ special_function_tool })
.enable_thinking(false)
.is_partial(true)
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
.run();
// Tool call with empty arguments
tst.test("<tool_call>\n{\"name\": \"empty_args\", \"arguments\": {}}</tool_call>")
.enable_thinking(false)
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
// fake tool call marker in reasoning
tst.test(
"Let me think about <tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 2}}</tool_call> hmm</think>"
"<tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}</tool_call>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect_reasoning("Let me think about <tool_call>\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 2}}</tool_call> hmm")
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
})
.run();
}
// Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format
// Format: <|tools_prefix|>[{"function_name": {...arguments...}}]<|tools_suffix|>
{

View File

@ -95,6 +95,12 @@ $ bin/rpc-server -c
By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable.
### RDMA transport
On Linux systems with RoCEv2-capable NICs (e.g. Mellanox ConnectX), the RPC backend can use RDMA instead of TCP for lower latency and higher throughput. The transport is negotiated automatically -- no changes to command-line usage are required.
RDMA is enabled by default when `libibverbs` is found at build time.
### Troubleshooting
Use the `GGML_RPC_DEBUG` environment variable to enable debug messages from `rpc-server`: