This commit is contained in:
Aman Gupta 2026-03-24 05:35:01 +02:00 committed by GitHub
commit 71ce063194
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 111 additions and 2 deletions

View File

@ -1681,6 +1681,88 @@ static void ggml_compute_forward_mul_mat_id(
}
}
static void ggml_compute_forward_fused_moe_silu(
const struct ggml_compute_params * params,
struct ggml_tensor * node0,
struct ggml_tensor * node1,
struct ggml_tensor * glu_node) {
const struct ggml_tensor * weights_gate;
const struct ggml_tensor * weights_up;
if (glu_node->src[0] == node0) {
weights_gate = node0->src[0];
weights_up = node1->src[0];
} else {
weights_gate = node1->src[0];
weights_up = node0->src[0];
}
const struct ggml_tensor * src1 = node0->src[1];
const struct ggml_tensor * ids = node0->src[2];
const int64_t ne00 = weights_gate->ne[0];
const int64_t ne01 = weights_gate->ne[1];
const size_t gate_nb01 = weights_gate->nb[1];
const size_t gate_nb02 = weights_gate->nb[2];
const size_t up_nb01 = weights_up->nb[1];
const size_t up_nb02 = weights_up->nb[2];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const size_t nb11 = src1->nb[1];
const size_t glu_nb1 = glu_node->nb[1];
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_type type = weights_gate->type;
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
const int n_ids = ids->ne[0]; // n_expert_used
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
const char * src1_q = (const char *) src1->data;
if (src1->type != vec_dot_type) {
char * wdata = (char *) params->wdata + ith * (ne11 * row_size + CACHE_LINE_SIZE);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float((float *)((char *) src1->data + i11*nb11),
(void *)(wdata + i11*row_size),
ne10);
}
src1_q = wdata;
}
// Process each selected expert directly (no row mapping needed)
for (int id = 0; id < n_ids; ++id) {
const int32_t expert_idx = *(const int32_t *) ((const char *) ids->data + id*ids->nb[0]);
const char * gate_cur = (const char *) weights_gate->data + expert_idx * gate_nb02;
const char * up_cur = (const char *) weights_up->data + expert_idx * up_nb02;
const char * src1_col = src1_q;
float * glu_col = (float *) ((char *) glu_node->data + id*glu_nb1);
// Static work division: each thread gets a contiguous range of rows
const int64_t ir0_start = (ith * ne01) / nth;
const int64_t ir0_end = ((ith + 1) * ne01) / nth;
for (int64_t ir0 = ir0_start; ir0 < ir0_end; ++ir0) {
float gate_val, up_val;
vec_dot(ne00, &gate_val, 0, gate_cur + ir0*gate_nb01, 0, src1_col, 0, 1);
vec_dot(ne00, &up_val, 0, up_cur + ir0*up_nb01, 0, src1_col, 0, 1);
glu_col[ir0] = ggml_silu_f32(gate_val) * up_val;
}
}
}
/////////////////////////////////
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@ -2809,7 +2891,12 @@ struct ggml_cplan ggml_graph_plan(
const int n_as = src0->ne[2];
// src1
if (src1->type != vec_dot_type) {
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
size_t quant_buf = ggml_row_size(vec_dot_type, ggml_nelements(src1));
// fused MoE path: each thread needs its own quantization buffer
if (src1->ne[2] == 1) {
quant_buf *= n_tasks;
}
cur += quant_buf + sizeof(int64_t);
}
// matrix_row_counts
cur += n_as * sizeof(int64_t) + sizeof(int64_t);
@ -2981,7 +3068,29 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
continue;
}
ggml_compute_forward(&params, node);
// Try fusion: MUL_MAT_ID + MUL_MAT_ID + GLU
int fused_nodes = 0;
if (node->op == GGML_OP_MUL_MAT_ID) {
enum ggml_op fuse_ops[3] = {GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU};
int outputs[1] = {node_n + 2};
if (ggml_can_fuse_subgraph(cgraph, node_n, 3, fuse_ops, outputs, 1)) {
struct ggml_tensor * node1 = cgraph->nodes[node_n + 1];
struct ggml_tensor * glu = cgraph->nodes[node_n + 2];
// Fused path for `--n-cpu-moe` when n_tokens = 1
if (node->src[1] == node1->src[1] && node->src[2] == node1->src[2] &&
ggml_nrows(node->src[1]) == 1 &&
ggml_get_glu_op(glu) == GGML_GLU_OP_SWIGLU) {
ggml_compute_forward_fused_moe_silu(&params, node, node1, glu);
fused_nodes = 2;
}
}
}
if (fused_nodes == 0) {
ggml_compute_forward(&params, node);
}
node_n += fused_nodes;
if (state->ith == 0 && cplan->abort_callback &&
cplan->abort_callback(cplan->abort_callback_data)) {