ggml-webgpu: Add the support of `MUL_MAT_ID` (#21147)
* Add mul_mat_id support to WebGPU * Apply suggestion from @reeselevine --------- Co-authored-by: Reese Levine <reeselevine1@gmail.com>
This commit is contained in:
parent
2e1f0a889e
commit
d0a6dfeb28
|
|
@ -68,7 +68,7 @@ Legend:
|
|||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | 🟡 | ❌ |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
|
|
|||
1145
docs/ops/WebGPU.csv
1145
docs/ops/WebGPU.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -658,6 +658,26 @@ struct ggml_webgpu_mul_mat_shader_decisions {
|
|||
uint32_t mul_mat_wg_size;
|
||||
};
|
||||
|
||||
/** MUL_MAT_ID **/
|
||||
|
||||
struct ggml_webgpu_mul_mat_id_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Cpy **/
|
||||
|
||||
struct ggml_webgpu_cpy_pipeline_key {
|
||||
|
|
@ -797,7 +817,10 @@ class ggml_webgpu_shader_lib {
|
|||
std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
|
||||
mul_mat_vec_pipelines; // fast mat-vec (n==1)
|
||||
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
|
||||
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
||||
mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup)
|
||||
std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed
|
||||
std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash>
|
||||
mul_mat_id_pipelines; // src0_type/src1_type
|
||||
|
||||
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
||||
set_rows_pipelines;
|
||||
|
|
@ -1598,6 +1621,115 @@ class ggml_webgpu_shader_lib {
|
|||
return mul_mat_legacy_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
auto it = mul_mat_id_gather_pipelines.find(1);
|
||||
if (it != mul_mat_id_gather_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::vector<std::string> defines;
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_mul_mat_id_gather, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, "mul_mat_id_gather");
|
||||
pipeline.context = decisions;
|
||||
mul_mat_id_gather_pipelines[1] = pipeline;
|
||||
return pipeline;
|
||||
}
|
||||
|
||||
webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_mul_mat_id_pipeline_key key = {
|
||||
.src0_type = context.src0->type,
|
||||
.src1_type = context.src1->type,
|
||||
};
|
||||
|
||||
auto it = mul_mat_id_pipelines.find(key);
|
||||
if (it != mul_mat_id_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "mul_mat_id";
|
||||
defines.push_back("MUL_MAT_ID");
|
||||
|
||||
// src1 type
|
||||
switch (context.src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC1_INNER_TYPE=f32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC1_INNER_TYPE=f16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
|
||||
}
|
||||
|
||||
// src0 type
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
const char * src0_name = src0_traits->type_name;
|
||||
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("FLOAT");
|
||||
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
||||
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("FLOAT");
|
||||
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
||||
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
{
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
|
||||
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
|
||||
variant += std::string("_") + src0_name;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
defines.push_back("SCALAR");
|
||||
|
||||
// Tiles
|
||||
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
||||
defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
||||
defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
|
||||
|
||||
defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
|
||||
defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
|
||||
|
||||
// variant suffix for src1 type
|
||||
variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines);
|
||||
|
||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
||||
decisions->tile_k = WEBGPU_MUL_MAT_TILE_K;
|
||||
decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
||||
decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
||||
decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
mul_mat_id_pipelines[key] = pipeline;
|
||||
return mul_mat_id_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
const bool is_unary = context.dst->op == GGML_OP_UNARY;
|
||||
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
|
||||
|
|
|
|||
|
|
@ -1376,6 +1376,163 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
|||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
|
||||
wgpu::CommandEncoder & encoder,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * src2,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.src2 = src2,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
// Get or create pipeline
|
||||
webgpu_pipeline gather_pipeline, main_pipeline;
|
||||
|
||||
std::vector<webgpu_pipeline> pipelines;
|
||||
std::vector<std::vector<uint32_t>> params_list;
|
||||
std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
|
||||
std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
|
||||
|
||||
gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx);
|
||||
main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx);
|
||||
|
||||
const uint32_t param_n_expert = (uint32_t) src0->ne[2];
|
||||
const uint32_t param_n_expert_used = (uint32_t) dst->ne[1];
|
||||
const uint32_t param_n_tokens = (uint32_t) dst->ne[2];
|
||||
|
||||
// params for mul_mat_id_gather.wgsl
|
||||
std::vector<uint32_t> gather_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)),
|
||||
param_n_expert,
|
||||
param_n_expert_used,
|
||||
param_n_tokens,
|
||||
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
|
||||
};
|
||||
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t);
|
||||
|
||||
const size_t gathered_expert_used_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t gathered_tokens_align_offset =
|
||||
ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes,
|
||||
ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t gathered_count_ids_align_offset =
|
||||
ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes,
|
||||
ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
|
||||
const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const size_t gathered_count_ids_binding_size =
|
||||
ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
// bind group entries for mul_mat_id_gather.wgsl
|
||||
std::vector<wgpu::BindGroupEntry> gather_entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src2),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src2) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = gathered_expert_used_align_offset,
|
||||
.size = gathered_binding_size },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = gathered_tokens_align_offset,
|
||||
.size = gathered_binding_size },
|
||||
{ .binding = 3,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = gathered_count_ids_align_offset,
|
||||
.size = gathered_count_ids_binding_size },
|
||||
};
|
||||
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
const uint32_t gather_total_wg = param_n_expert;
|
||||
const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim);
|
||||
const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x);
|
||||
|
||||
pipelines.push_back(gather_pipeline);
|
||||
params_list.push_back(std::move(gather_params));
|
||||
entries_list.push_back(std::move(gather_entries));
|
||||
workgroups_list.push_back({ gather_wg_x, gather_wg_y });
|
||||
|
||||
// params for mul_mat_id.wgsl
|
||||
std::vector<uint32_t> main_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) src0->ne[0],
|
||||
(uint32_t) src0->ne[1],
|
||||
param_n_expert,
|
||||
param_n_expert_used,
|
||||
param_n_tokens,
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
};
|
||||
|
||||
// bind group entries for mul_mat_id.wgsl
|
||||
std::vector<wgpu::BindGroupEntry> main_entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
|
||||
{ .binding = 3,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = gathered_expert_used_align_offset,
|
||||
.size = gathered_binding_size },
|
||||
{ .binding = 4,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = gathered_tokens_align_offset,
|
||||
.size = gathered_binding_size },
|
||||
{ .binding = 5,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = gathered_count_ids_align_offset,
|
||||
.size = gathered_count_ids_binding_size },
|
||||
};
|
||||
|
||||
// Calculate workgroup dimensions
|
||||
uint32_t wg_x = 1;
|
||||
uint32_t wg_y = 1;
|
||||
|
||||
auto * main_decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(main_pipeline.context.get());
|
||||
|
||||
uint32_t wg_m;
|
||||
|
||||
uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m;
|
||||
uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n;
|
||||
wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
|
||||
uint32_t total_gathered = dst->ne[1] * dst->ne[2];
|
||||
uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered);
|
||||
uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts;
|
||||
uint32_t total_wg = wg_m * max_wg_n;
|
||||
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
pipelines.push_back(main_pipeline);
|
||||
params_list.push_back(std::move(main_params));
|
||||
entries_list.push_back(std::move(main_entries));
|
||||
workgroups_list.push_back({ wg_x, wg_y });
|
||||
|
||||
return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list,
|
||||
entries_list, workgroups_list);
|
||||
}
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
wgpu::CommandEncoder & encoder,
|
||||
|
|
@ -2638,6 +2795,8 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_encode_node(webgpu_context
|
|||
return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node);
|
||||
case GGML_OP_MUL_MAT:
|
||||
return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node);
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node);
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
#ifndef __EMSCRIPTEN__
|
||||
return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node);
|
||||
|
|
@ -3082,6 +3241,20 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
|||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
const ggml_tensor * src0 = tensor->src[0];
|
||||
const ggml_tensor * src1 = tensor->src[1];
|
||||
if (src0 && src1) {
|
||||
const size_t gathered_size = sizeof(uint32_t) * tensor->src[0]->ne[2] * tensor->src[1]->ne[2];
|
||||
const size_t gathered_count_ids_size = sizeof(uint32_t) * tensor->src[0]->ne[2];
|
||||
res = ROUNDUP_POW2(
|
||||
res + gathered_size * 2 + gathered_count_ids_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment * 3,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -3503,6 +3676,35 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F16:
|
||||
supports_op |= (src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
supports_op = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
#ifndef __EMSCRIPTEN__
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
}
|
||||
#endif // INIT_SRC0_SHMEM_FLOAT
|
||||
|
||||
#ifndef MUL_MAT_ID
|
||||
#ifdef INIT_SRC1_SHMEM_FLOAT
|
||||
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
|
||||
|
|
@ -58,6 +59,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
|
|||
}
|
||||
}
|
||||
#endif // INIT_SRC1_SHMEM_FLOAT
|
||||
#endif
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,193 @@
|
|||
enable f16;
|
||||
|
||||
#include "common_decls.tmpl"
|
||||
#include "mul_mat_decls.tmpl"
|
||||
|
||||
#ifdef VEC
|
||||
fn store_val(acc: array<array<f16, TILE_M>, TILE_N>, tn: u32, tm: u32) -> vec4<f32> {
|
||||
return vec4<f32>(f32(acc[tn][tm]), f32(acc[tn][tm + 1]), f32(acc[tn][tm + 2]), f32(acc[tn][tm + 3]));
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef SCALAR
|
||||
fn store_val(acc: array<array<f16, TILE_M>, TILE_N>, tn: u32, tm: u32) -> f32 {
|
||||
return f32(acc[tn][tm]);
|
||||
}
|
||||
#endif
|
||||
|
||||
struct MulMatIdParams {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
k: u32,
|
||||
m: u32,
|
||||
n_expert: u32,
|
||||
n_expert_used: u32,
|
||||
n_tokens: u32,
|
||||
b_ne1: u32,
|
||||
|
||||
stride_01: u32,
|
||||
stride_11: u32,
|
||||
stride_02: u32,
|
||||
stride_12: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // [cols, rows, n_expert]
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // [cols, b_ne1, n_tokens]
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // [rows, n_expert_used, n_tokens]
|
||||
@group(0) @binding(3) var<storage, read_write> global_gathered_expert_used: array<u32>; // [n_expert][n_tokens]
|
||||
@group(0) @binding(4) var<storage, read_write> global_gathered_tokens: array<u32>; // [n_expert][n_tokens]
|
||||
@group(0) @binding(5) var<storage, read_write> gathered_count_ids: array<u32>; // [n_expert]
|
||||
|
||||
@group(0) @binding(6) var<uniform> params: MulMatIdParams;
|
||||
|
||||
fn get_local_n(thread_id: u32) -> u32 {
|
||||
return thread_id / WORKGROUP_SIZE_M;
|
||||
}
|
||||
fn get_local_m(thread_id: u32) -> u32 {
|
||||
return thread_id % WORKGROUP_SIZE_M;
|
||||
}
|
||||
|
||||
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
|
||||
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
|
||||
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
|
||||
var<workgroup> gathered_expert_used: array<u32, TILE_N * WORKGROUP_SIZE_N>;
|
||||
var<workgroup> gathered_tokens: array<u32, TILE_N * WORKGROUP_SIZE_N>;
|
||||
|
||||
#ifdef INIT_SRC1_SHMEM_FLOAT
|
||||
fn init_shmem_id_src1(thread_id: u32, offset_src1: u32, rest_token_n: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
|
||||
let tile_n = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
if (tile_n < rest_token_n) {
|
||||
let global_src10 = k_outer + tile_k;
|
||||
let expert_used_idx = gathered_expert_used[tile_n] % params.b_ne1;
|
||||
let token_idx = gathered_tokens[tile_n];
|
||||
let src1_idx = offset_src1 + token_idx * params.stride_12 + expert_used_idx * params.stride_11 + global_src10;
|
||||
let src1_val = select(
|
||||
SRC1_TYPE(0.0),
|
||||
src1[src1_idx/VEC_SIZE],
|
||||
global_src10 < params.k);
|
||||
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
|
||||
} else {
|
||||
store_shmem(SHMEM_TYPE(0.0), TILE_SRC0_SHMEM + elem_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC1_SHMEM_FLOAT
|
||||
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
|
||||
let thread_id = local_id.x;
|
||||
let local_m = get_local_m(thread_id);
|
||||
let local_n = get_local_n(thread_id);
|
||||
|
||||
var expert_idx:u32 = 0xFFFFFFFFu;
|
||||
var wg_in_batch:u32 = 0;
|
||||
var wg_sum:u32 = 0;
|
||||
let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
|
||||
for (var i = 0u;i < params.n_expert;i += 1) {
|
||||
let wg_n_count = (gathered_count_ids[i] + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
|
||||
let wg_per_matrix = wg_m_count * wg_n_count;
|
||||
if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) {
|
||||
expert_idx = i;
|
||||
wg_in_batch = wg_linear - wg_sum;
|
||||
break;
|
||||
}
|
||||
wg_sum += wg_per_matrix;
|
||||
}
|
||||
|
||||
let is_valid = expert_idx != 0xFFFFFFFFu;
|
||||
|
||||
var wg_m: u32 = 0;
|
||||
var wg_n: u32 = 0;
|
||||
var offset_wg_m: u32 = 0;
|
||||
var offset_wg_n: u32 = 0;
|
||||
var rest_token_n: u32 = 0;
|
||||
var src0_batch_offset: u32 = 0;
|
||||
|
||||
wg_m = wg_in_batch % wg_m_count;
|
||||
wg_n = wg_in_batch / wg_m_count;
|
||||
|
||||
offset_wg_m = wg_m * WORKGROUP_SIZE_M * TILE_M;
|
||||
offset_wg_n = wg_n * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
if (is_valid) {
|
||||
rest_token_n = gathered_count_ids[expert_idx] - offset_wg_n;
|
||||
let global_gathered_base = expert_idx * params.n_tokens + offset_wg_n;
|
||||
for (var i = thread_id; i < TILE_N * WORKGROUP_SIZE_N && offset_wg_n + i < gathered_count_ids[expert_idx]; i += TOTAL_WORKGROUP_SIZE) {
|
||||
gathered_expert_used[i] = global_gathered_expert_used[global_gathered_base + i];
|
||||
gathered_tokens[i] = global_gathered_tokens[global_gathered_base + i];
|
||||
}
|
||||
src0_batch_offset = params.offset_src0 + expert_idx * params.stride_02;
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
let output_row_base = offset_wg_m + local_m * TILE_M;
|
||||
let output_col_base = offset_wg_n + local_n * TILE_N;
|
||||
|
||||
let dst2_stride = params.m * params.n_expert_used;
|
||||
let dst1_stride = params.m;
|
||||
|
||||
var acc: array<array<f16, TILE_M>, TILE_N>;
|
||||
|
||||
for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) {
|
||||
|
||||
if (is_valid) {
|
||||
init_shmem_src0(thread_id, src0_batch_offset, offset_wg_m, k_outer);
|
||||
init_shmem_id_src1(thread_id, params.offset_src1, rest_token_n, k_outer);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (is_valid) {
|
||||
let k_end = min(TILE_K, params.k - k_outer);
|
||||
|
||||
for (var k_inner = 0u; k_inner < k_end; k_inner++) {
|
||||
var src0_tile: array<f16, TILE_M>;
|
||||
for (var tm = 0u; tm < TILE_M; tm++) {
|
||||
let src0_m = local_m * TILE_M + tm;
|
||||
let src0_idx = k_inner + src0_m * TILE_K;
|
||||
src0_tile[tm] = shmem[src0_idx];
|
||||
}
|
||||
for (var tn = 0u; tn < TILE_N; tn++) {
|
||||
let src1_n = local_n * TILE_N + tn;
|
||||
let src1_idx = src1_n * TILE_K + k_inner;
|
||||
let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx];
|
||||
for (var tm = 0u; tm < TILE_M; tm++) {
|
||||
acc[tn][tm] += src0_tile[tm] * src1_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
if (is_valid) {
|
||||
for (var tn = 0u; tn < TILE_N; tn++) {
|
||||
let n_idx = output_col_base + tn;
|
||||
if (n_idx < gathered_count_ids[expert_idx]) {
|
||||
let dst1_idx = gathered_expert_used[n_idx - offset_wg_n];
|
||||
let dst2_idx = gathered_tokens[n_idx - offset_wg_n];
|
||||
let dst12_offset = params.offset_dst + dst2_idx * dst2_stride + dst1_idx * dst1_stride;
|
||||
for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
|
||||
let global_row = output_row_base + tm;
|
||||
if (global_row < params.m) {
|
||||
let dst_idx = dst12_offset + global_row;
|
||||
dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
enable f16;
|
||||
|
||||
struct MulMatIdGatherParams {
|
||||
offset_ids: u32,
|
||||
|
||||
n_expert: u32,
|
||||
n_expert_used: u32,
|
||||
n_tokens: u32,
|
||||
|
||||
stride_ids_1: u32,
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> ids: array<i32>; // [n_expert_used, n_tokens]
|
||||
@group(0) @binding(1) var<storage, read_write> global_gathered_expert_used: array<u32>; // [n_expert][n_tokens]
|
||||
@group(0) @binding(2) var<storage, read_write> global_gathered_tokens: array<u32>; // [n_expert][n_tokens]
|
||||
@group(0) @binding(3) var<storage, read_write> gathered_count_ids: array<u32>; // [n_expert]
|
||||
|
||||
@group(0) @binding(4) var<uniform> params: MulMatIdGatherParams;
|
||||
|
||||
var<workgroup> count:atomic<u32>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
|
||||
let thread_id = local_id.x;
|
||||
let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup
|
||||
|
||||
if (own_expert < params.n_expert) {
|
||||
if (thread_id == 0u) {
|
||||
atomicStore(&count, 0);
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) {
|
||||
let row = i / params.n_expert_used;
|
||||
let col = i % params.n_expert_used;
|
||||
let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]);
|
||||
if (own_expert == expert) {
|
||||
let pos = atomicAdd(&count, 1u);
|
||||
let gathered_id = own_expert * params.n_tokens + pos;
|
||||
global_gathered_expert_used[gathered_id] = col;
|
||||
global_gathered_tokens[gathered_id] = row;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
if (thread_id == 0u) {
|
||||
gathered_count_ids[own_expert] = atomicLoad(&count);
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue