ggml-zendnn : add MUL_MAT_ID op support for MoE models (#21315)
* ggml-zendnn : add MUL_MAT_ID op support for MoE models - Add MUL_MAT_ID op acceleration for Mixture-of-Experts models - MUL_MAT_ID op fallback to CPU backend if total experts > 32 - Point ZenDNN lib to latest bits ZenDNN-2026-WW13 * ggml-zendnn : add braces to sgemm failure condition for consistency Co-authored-by: Aaron Teo <taronaeo@gmail.com> --------- Co-authored-by: Aaron Teo <taronaeo@gmail.com>
This commit is contained in:
parent
b069b10ab4
commit
f1ac84119c
|
|
@ -57,13 +57,14 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based
|
|||
|
||||
## Supported Operations
|
||||
|
||||
The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** operations only. Other operations are handled by the standard CPU backend.
|
||||
The ZenDNN backend accelerates **matrix multiplication (MUL_MAT)** and **expert-based matrix multiplication (MUL_MAT_ID)** operations. Other operations are handled by the standard CPU backend.
|
||||
|
||||
| Operation | Status | Notes |
|
||||
|:-------------|:-------:|:----------------------------------------------:|
|
||||
| MUL_MAT | Support | Accelerated via ZenDNN LowOHA MatMul |
|
||||
| MUL_MAT_ID | Support | Accelerated via ZenDNN LowOHA MatMul (MoE) |
|
||||
|
||||
*Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs).
|
||||
*Note:* Since MUL_MAT and MUL_MAT_ID are accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs and Mixture-of-Experts models).
|
||||
|
||||
## DataType Supports
|
||||
|
||||
|
|
@ -181,7 +182,7 @@ For detailed profiling and logging options, refer to the [ZenDNN Logging Documen
|
|||
|
||||
## Known Issues
|
||||
|
||||
- **Limited operation support**: Currently only matrix multiplication (MUL_MAT) is accelerated via ZenDNN. Other operations fall back to the standard CPU backend.
|
||||
- **Limited operation support**: Currently matrix multiplication (MUL_MAT) and expert-based matrix multiplication (MUL_MAT_ID) are accelerated via ZenDNN. Other operations fall back to the standard CPU backend. Future updates may expand supported operations.
|
||||
- **BF16 support**: BF16 operations require AMD Zen 4 or Zen 5 architecture (EPYC 9004/9005 series). On older CPUs, operations will use FP32.
|
||||
- **NUMA awareness**: For multi-socket systems, manual NUMA binding may be required for optimal performance.
|
||||
|
||||
|
|
@ -216,4 +217,4 @@ Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-t
|
|||
|
||||
## TODO
|
||||
|
||||
- Expand operation support beyond MUL_MAT (attention operations, activations, etc.)
|
||||
- Expand operation support beyond MUL_MAT and MUL_MAT_ID (attention operations, activations, etc.)
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ Legend:
|
|||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | 🟡 | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
|
|
|||
9986
docs/ops/ZenDNN.csv
9986
docs/ops/ZenDNN.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
|
|||
ExternalProject_Add(
|
||||
zendnn
|
||||
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
|
||||
GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08
|
||||
GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
|
||||
PREFIX ${ZENDNN_PREFIX}
|
||||
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
|
||||
BINARY_DIR ${ZENDNN_BUILD_DIR}
|
||||
|
|
|
|||
|
|
@ -190,6 +190,170 @@ static void ggml_zendnn_compute_forward_mul_mat(
|
|||
}
|
||||
}
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
static void ggml_zendnn_compute_forward_mul_mat_id(
|
||||
ggml_backend_zendnn_context * ctx,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0]; // expert weights
|
||||
const ggml_tensor * src1 = dst->src[1]; // inputs
|
||||
const ggml_tensor * ids = dst->src[2]; // expert ids
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
// exit for no tokens to process
|
||||
if (ne2 == 0 || ne11 == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_type const vec_dot_type = src0->type;
|
||||
ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne3 == 1);
|
||||
|
||||
// row groups
|
||||
const int n_ids = ids->ne[0]; // n_expert_used
|
||||
const int n_as = ne02; // n_experts
|
||||
|
||||
std::vector<int64_t> matrix_row_counts(n_as, 0);
|
||||
std::vector<std::vector<mmid_row_mapping>> matrix_rows(n_as);
|
||||
|
||||
int64_t max_rows = 0;
|
||||
// group rows by expert (preprocessing step)
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
||||
for (int id = 0; id < n_ids; ++id) {
|
||||
const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
|
||||
matrix_rows[i02].push_back({id, iid1});
|
||||
matrix_row_counts[i02]++;
|
||||
if (matrix_row_counts[i02] > max_rows) {
|
||||
max_rows = matrix_row_counts[i02];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (max_rows == 0) {
|
||||
return; // no rows to process
|
||||
}
|
||||
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
// size for converting src1 rows to vec_dot_type if needed
|
||||
const size_t nbw1 = row_size;
|
||||
const size_t nbw2 = nbw1 * ne11;
|
||||
const size_t nbw3 = nbw2 * ne12;
|
||||
const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0;
|
||||
|
||||
// size for MoE gather/scatter buffers
|
||||
const size_t wdata_cur_size = max_rows * row_size;
|
||||
const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01);
|
||||
|
||||
// allocate single buffer for all needs
|
||||
const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size;
|
||||
if (ctx->work_size < total_size) {
|
||||
ctx->work_data.reset(new char[total_size]);
|
||||
ctx->work_size = total_size;
|
||||
}
|
||||
|
||||
// partition the buffer
|
||||
char * work_data = ctx->work_data.get();
|
||||
char * wdata_cur = work_data + src1_conv_size;
|
||||
char * dst_cur = wdata_cur + wdata_cur_size;
|
||||
|
||||
if (src1->type != vec_dot_type) {
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
#pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
|
||||
from_float(src1_f32, src1_conv, ne10);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
|
||||
|
||||
// process each expert with gather -> gemm -> scatter pattern
|
||||
for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||
const int64_t cne1 = matrix_row_counts[cur_a];
|
||||
|
||||
if (cne1 == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
|
||||
|
||||
// gather input rows for this expert
|
||||
#pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
|
||||
for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
|
||||
const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
|
||||
const int64_t id = row_mapping.i1;
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2;
|
||||
|
||||
std::memcpy(
|
||||
wdata_cur + ir1 * row_size,
|
||||
(const char *) wdata + (i11 + i12*ne11) * row_size,
|
||||
row_size
|
||||
);
|
||||
}
|
||||
|
||||
// batched gemm for all tokens in this expert
|
||||
if (!ggml_zendnn_sgemm(ctx,
|
||||
ne01, // m
|
||||
cne1, // n
|
||||
ne10, // k
|
||||
src0_cur,
|
||||
ne00, // lda
|
||||
wdata_cur,
|
||||
ne10, // ldb
|
||||
dst_cur,
|
||||
ne01, // ldc
|
||||
src0->type,
|
||||
vec_dot_type,
|
||||
dst->type)) {
|
||||
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
|
||||
}
|
||||
|
||||
// scatter output rows to destination
|
||||
#pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
|
||||
for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
|
||||
const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
|
||||
const int64_t id = row_mapping.i1;
|
||||
const int64_t i1 = id;
|
||||
const int64_t i2 = row_mapping.i2;
|
||||
|
||||
std::memcpy(
|
||||
(char *) dst->data + i1*nb1 + i2*nb2,
|
||||
dst_cur + ir1 * ggml_row_size(dst->type, ne01),
|
||||
ggml_row_size(dst->type, ne01)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// backend interface
|
||||
|
||||
static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {
|
||||
|
|
@ -218,6 +382,9 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
|
|||
case GGML_OP_MUL_MAT:
|
||||
ggml_zendnn_compute_forward_mul_mat(ctx, node);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
ggml_zendnn_compute_forward_mul_mat_id(ctx, node);
|
||||
break;
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
|
|
@ -361,6 +528,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
|
|||
return true;
|
||||
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
const ggml_tensor * weights = op->src[0];
|
||||
const ggml_tensor * inputs = op->src[1];
|
||||
|
|
@ -374,6 +542,17 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
|
|||
ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
|
||||
return false;
|
||||
}
|
||||
// MUL_MAT_ID performs best with a moderate number of experts due to its
|
||||
// gather + batched matmul + scatter approach. Future versions will leverage
|
||||
// ZenDNN's grouped_gemm for better scalability with larger expert counts:
|
||||
// https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md
|
||||
if (op->op == GGML_OP_MUL_MAT_ID) {
|
||||
const int64_t n_experts = weights->ne[2];
|
||||
const int64_t max_experts = 32;
|
||||
if (n_experts > max_experts) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
switch (weights->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_BF16:
|
||||
|
|
|
|||
Loading…
Reference in New Issue