Merge branch 'ggml-org:master' into power-law-sampler

This commit is contained in:
ddh0 2025-12-21 14:41:16 -06:00 committed by GitHub
commit 55ad4a89d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 291 additions and 82 deletions

View File

@ -1,4 +1,4 @@
# Copilot Instructions for llama.cpp
# Instructions for llama.cpp
## Repository Overview
@ -12,6 +12,24 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
- **License**: MIT
## Disclose AI Usage
It is crucial to remind contributors that the project mandates disclosure of any AI usage in pull requests. This requirement stems from the potential for AI-generated code to include suboptimal optimizations and hidden bugs, owing to the inherent overconfidence in AI outputs.
When generating significant portions of code, address this by:
- Informing the user that AI-generated content may be rejected by maintainers.
- Clearly marking AI-generated code in commit messages and comments.
- Example of commit message: `[AI] Fix a race condition in ...`
- Example of code comment: `// [AI] spawn a new thread ...`
These measures apply to:
- Changes resulting in large portions of code or complex logic.
- Modifications or additions to public APIs in `llama.h`, `ggml.h`, or `mtmd.h`.
- Backend-related changes, such as those involving CPU, CUDA, Metal, Vulkan, etc.
- Modifications to `tools/server`.
Note: These measures can be omitted for small fixes or trivial changes.
## Build Instructions
### Prerequisites
@ -251,6 +269,7 @@ Primary tools:
- **Cross-platform compatibility**: Test on Linux, macOS, Windows when possible
- **Performance focus**: This is a performance-critical inference library
- **API stability**: Changes to `include/llama.h` require careful consideration
- **Disclose AI Usage**: Refer to the "Disclose AI Usage" earlier in this document
### Git Workflow
- Always create feature branches from `master`

View File

@ -55,7 +55,7 @@ auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder &
```
For a more complete example, see `test_example_native()` in
[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp).
[tests/test-chat-peg-parser.cpp](/tests/test-chat-peg-parser.cpp).
## Parsers/Combinators
@ -175,7 +175,7 @@ Most model output can be placed in one of the following categories:
(Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2)
To provide broad coverage,
[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and
[`common/chat-peg-parser.h`](/common/chat-peg-parser.h) contains builders and
mappers that help create parsers and visitors/extractors for these types. They
require parsers to tag nodes to conform to an AST "shape". This normalization
makes it easy to extract information and generalize parsing.

View File

@ -3076,8 +3076,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
@ -3085,7 +3088,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}

View File

@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
}
}
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert) {
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
return false;
}
float scale = 1.0f;
float max_bias = 0.0f;
@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return false;
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;

View File

@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
const ggml_tensor * weights,
const ggml_tensor * get_rows,
const ggml_tensor * argsort,
const ggml_tensor * clamp,
int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

View File

@ -583,7 +583,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
if (tensor->buffer) {
ggml_backend_buffer_t buffer = tensor->buffer;
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
result.buffer = ctx->remote_ptr;
result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
} else {
result.buffer = 0;
}

View File

@ -689,6 +689,7 @@ struct vk_device_struct {
vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
vk_pipeline pipeline_xielu[2];
vk_pipeline pipeline_neg[2];
vk_pipeline pipeline_tanh[2];
vk_pipeline pipeline_sigmoid[2];
@ -990,6 +991,8 @@ struct vk_op_push_constants {
uint32_t KY;
float param1;
float param2;
float param3;
float param4;
};
struct vk_op_glu_push_constants {
@ -1258,6 +1261,7 @@ struct vk_op_im2col_push_constants {
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
int32_t d0; int32_t d1;
uint32_t batch_IC;
};
struct vk_op_im2col_3d_push_constants {
@ -3973,6 +3977,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY(gelu_quick)
CREATE_UNARY(silu)
CREATE_UNARY(relu)
CREATE_UNARY(xielu)
CREATE_UNARY(neg)
CREATE_UNARY(tanh)
CREATE_UNARY(sigmoid)
@ -5898,6 +5903,9 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
}
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
@ -8549,6 +8557,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_RELU:
return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_XIELU:
return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_NEG:
return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_TANH:
@ -9084,6 +9094,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
elements = { OW * KW * KH, OH, batch * IC };
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
case GGML_OP_IM2COL_3D:
{
@ -9695,14 +9707,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
ggml_vk_op_f32_opt_step_adamw(
ctx, subctx, dst,
{ (uint32_t)n, 0, 0.0f, 0.0f }
{ (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
);
}
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const size_t n = ggml_nelements(dst->src[0]);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -9788,6 +9800,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
1,
ggml_get_op_params_f32(dst, 0),
ggml_get_op_params_f32(dst, 2),
0.0f, 0.0f,
};
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
@ -9809,6 +9822,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml
1,
ggml_get_op_params_f32(dst, 0),
0.0f,
0.0f, 0.0f,
};
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
@ -9924,13 +9938,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
}
static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@ -9941,7 +9955,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
const float eps = float_op_params[1];
const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
}
static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@ -10110,16 +10124,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
{
(uint32_t)ggml_nelements(src0), 0,
op_params[1], op_params[2], op_params[3], op_params[4]
}
);
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -10244,7 +10268,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
@ -10541,11 +10565,11 @@ static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -10587,6 +10611,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
const uint32_t pelements = OW * KW * KH;
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;
@ -10599,7 +10624,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
IC, IW, IH, OW, OH, KW, KH,
pelements,
IC * KH * KW,
s0, s1, p0, p1, d0, d1,
s0, s1, p0, p1, d0, d1, batch * IC
});
}
@ -10804,7 +10829,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const float * op_params = (const float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
}
#ifdef GGML_VULKAN_RUN_TESTS
@ -12050,6 +12075,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_TRUNC:
ggml_vk_unary(ctx, compute_ctx, src0, node);
break;
case GGML_UNARY_OP_XIELU:
ggml_vk_xielu(ctx, compute_ctx, src0, node);
break;
default:
return false;
}
@ -12920,24 +12948,43 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
const ggml_tensor * softmax;
const ggml_tensor * weights;
const ggml_tensor * get_rows;
const ggml_tensor * argsort;
switch (mode) {
case TOPK_MOE_EARLY_SOFTMAX_NORM:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 9];
get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_EARLY_SOFTMAX:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 4];
get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_LATE_SOFTMAX:
softmax = cgraph->nodes[node_idx + 4];
weights = cgraph->nodes[node_idx + 5];
get_rows = cgraph->nodes[node_idx + 2];
argsort = cgraph->nodes[node_idx + 0];
break;
default:
return false;
}
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
return false;
}
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
@ -13502,7 +13549,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
ok = false;
break;
}
@ -13842,6 +13890,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_XIELU:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
@ -14747,7 +14796,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_LOG) {
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_TRI) {
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
} else if (tensor->op == GGML_OP_DIAG) {
tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {
@ -14835,6 +14884,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_RELU:
tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_XIELU:
tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
break;
case GGML_UNARY_OP_NEG:
tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
break;

View File

@ -6,4 +6,6 @@ layout (push_constant) uniform parameter
uint KY;
float param1;
float param2;
float param3;
float param4;
} p;

View File

@ -19,6 +19,7 @@ layout (push_constant) uniform parameter
int s0; int s1;
int p0; int p1;
int d0; int d1;
uint batch_IC;
} p;
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
void im2col(const uint y, const uint z) {
const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint oh = y;
const uint batch = z / p.IC;
const uint ic = z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
@ -101,3 +102,15 @@ void main() {
#endif
}
}
void main() {
uint y = gl_GlobalInvocationID.y;
while (y < p.OH) {
uint z = gl_GlobalInvocationID.z;
while (z < p.batch_IC) {
im2col(y, z);
z += gl_NumWorkGroups.z;
}
y += gl_NumWorkGroups.y;
}
}

View File

@ -11,36 +11,54 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
// Precompute db multiplication factors
float db_vals[NUM_ROWS];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
const float db = d * (0.5 + scale) * 0.25;
const uint scale_raw = data_a[ibi].scales[ib32];
const uint scale = (scale_raw >> nibble_shift) & 0xF;
// Merge constant calculations d * (0.5 + scale) * 0.25 = d*0.125 + d*scale*0.25
db_vals[n] = d * (0.125f + float(scale) * 0.25f);
ibi += num_blocks_per_row;
}
ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
// Preload grid and sign data for all l values
vec4 grid0_vals[2], grid1_vals[2];
uint sign_vals[2], sign7_vals[2];
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs = data_a[ibi].qs[2 * itid + l];
const uint sign = qs >> 9;
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x));
const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
sign_vals[l] = qs >> 9;
sign7_vals[l] = bitCount(sign_vals[l]);
const uvec2 grid_data = iq2xs_grid[qs & 511];
grid0_vals[l] = vec4(unpack8(grid_data.x));
grid1_vals[l] = vec4(unpack8(grid_data.y));
}
// Preload B data for all j columns (reduce repeated index calculations)
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint sign = sign_vals[l];
const uint sign7 = sign7_vals[l];
const vec4 grid0 = grid0_vals[l];
const vec4 grid1 = grid1_vals[l];
// Precompute indices
const uint b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4 + 2 * l;
const vec4 b0 = vec4(data_b_v4[b_idx + 0]);
const vec4 b4 = vec4(data_b_v4[b_idx + 1]);
sum +=
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
}
temp[j][n] = fma(FLOAT_TYPE(db_vals[n]), sum, temp[j][n]);
}
ibi += num_blocks_per_row;
}

View File

@ -853,6 +853,8 @@ void process_shaders() {
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

View File

@ -0,0 +1,35 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
float x = float(data_a[i]);
float alpha_n = p.param1;
float alpha_p = p.param2;
float beta = p.param3;
float eps = p.param4;
if (x > 0.0f) {
x = alpha_p * x * x + beta * x;
} else {
const float min_x_eps = min(x, eps);
x = (exp(min_x_eps) - 1 - x) * alpha_n + beta * x;
}
data_d[i] = D_TYPE(x);
}

View File

@ -5118,25 +5118,36 @@ struct test_top_k : public test_case {
}
};
enum MoeGatingFunc {
GATING_FUNC_SOFTMAX,
GATING_FUNC_SIGMOID,
GATING_FUNC_SOFTMAX_WEIGHT,
};
struct test_topk_moe : public test_case {
const std::array<int64_t, 4> ne;
const int n_expert_used;
const bool with_norm;
const bool delayed_softmax;
const bool bias_probs;
const MoeGatingFunc gating_func;
const float scale_w;
test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
int n_expert_used = 1,
bool with_norm = false,
bool delayed_softmax = false) :
bool bias_probs = false,
MoeGatingFunc gating_func = GATING_FUNC_SOFTMAX,
float scale_w = 0.0f) :
ne(ne),
n_expert_used(n_expert_used),
with_norm(with_norm),
delayed_softmax(delayed_softmax) {
bias_probs(bias_probs),
gating_func(gating_func),
scale_w(scale_w) {
GGML_ASSERT(n_expert_used <= ne[0]);
GGML_ASSERT(!(with_norm && delayed_softmax));
}
std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
std::string vars() override { return VARS_TO_STR6(ne, n_expert_used, with_norm, bias_probs, gating_func, scale_w); }
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@ -5150,28 +5161,47 @@ struct test_topk_moe : public test_case {
const int n_tokens = ne[1];
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * probs =
(gating_func == GATING_FUNC_SOFTMAX) ? ggml_soft_max(ctx, logits) :
(gating_func == GATING_FUNC_SIGMOID) ? ggml_sigmoid(ctx, logits) : logits;
ggml_set_name(probs, "probs");
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
ggml_tensor * selection_probs = probs;
if (bias_probs) {
ggml_tensor * exp_probs_b = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
ggml_set_name(exp_probs_b, "exp_probs_b");
selection_probs = ggml_add(ctx, probs, exp_probs_b);
ggml_set_name(selection_probs, "selection_probs");
}
if (delayed_softmax) {
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_set_name(selected_experts, "selected_experts");
ggml_tensor * weights = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
ggml_set_name(weights, "weights");
if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
weights = ggml_soft_max(ctx, weights); // [n_expert_used, n_tokens]
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
}
if (with_norm) {
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
ggml_set_name(weights_sum, "weights_sum");
weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
}
ggml_set_name(out, "out");
return out;
if (scale_w) {
weights = ggml_scale(ctx, weights, scale_w);
}
ggml_set_name(weights, "weights");
return weights;
}
};
@ -6900,6 +6930,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
// im2col 3D
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
@ -7991,19 +8022,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
for (bool with_norm : {false, true}) {
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm));
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm));
for (auto gate : {GATING_FUNC_SOFTMAX, GATING_FUNC_SIGMOID, GATING_FUNC_SOFTMAX_WEIGHT}) {
for (bool with_norm : {false, true}) {
for (bool bias_probs : {false, true}) {
for (float scale_w : {0.0f, 2.0f}) {
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({31, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({40, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w));
}
}
}
}
test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
#if 0
// these tests are disabled to save execution time, sbut they can be handy for debugging
test_cases.emplace_back(new test_llama(2, true));