Add get_rows implementation
This commit is contained in:
parent
7fbe84cd5f
commit
dc7bc4a25a
|
|
@ -124,6 +124,7 @@ struct webgpu_context_struct {
|
|||
wgpu::ComputePipeline memset_pipeline;
|
||||
wgpu::ComputePipeline mul_mat_pipeline[30][2];
|
||||
wgpu::ComputePipeline set_rows_pipeline;
|
||||
wgpu::ComputePipeline get_rows_pipeline[30];
|
||||
wgpu::ComputePipeline cpy_pipeline;
|
||||
wgpu::ComputePipeline add_pipeline[2];
|
||||
wgpu::ComputePipeline add_ip_pipeline[2];
|
||||
|
|
@ -555,6 +556,45 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
|
|||
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
|
||||
}
|
||||
|
||||
// Encodes a GET_ROWS dispatch.
//
// Binds src (binding 0), idx (binding 1) and dst (binding 2), packs their
// offsets, strides and shapes into the uniform parameter list, and enqueues
// the get_rows pipeline selected by src->type.
static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
    // Expresses a byte quantity in units of ggml_type_size() of the tensor's type.
    auto in_type_units = [](auto bytes, const ggml_tensor * t) {
        return (uint32_t) (bytes / ggml_type_size(t->type));
    };

    // NOTE(review): the order of these values presumably has to match the parameter
    // layout declared in the get_rows shaders, which are not visible here.
    std::vector<uint32_t> params = {
        // Offset of each tensor from its aligned binding offset
        in_type_units(ggml_webgpu_tensor_misalignment(ctx, src), src),
        in_type_units(ggml_webgpu_tensor_misalignment(ctx, idx), idx),
        in_type_units(ggml_webgpu_tensor_misalignment(ctx, dst), dst),
        // Byte strides converted to type-size units: src nb[1..3]
        in_type_units(src->nb[1], src),
        in_type_units(src->nb[2], src),
        in_type_units(src->nb[3], src),
        // idx nb[0..2]
        in_type_units(idx->nb[0], idx),
        in_type_units(idx->nb[1], idx),
        in_type_units(idx->nb[2], idx),
        // dst nb[1..3]
        in_type_units(dst->nb[1], dst),
        in_type_units(dst->nb[2], dst),
        in_type_units(dst->nb[3], dst),
        // Shape of dst
        (uint32_t) dst->ne[0],
        (uint32_t) dst->ne[1],
        (uint32_t) dst->ne[2],
        (uint32_t) dst->ne[3],
        // Shape of idx (dimensions 1 and 2 only)
        (uint32_t) (idx->ne[1]),
        (uint32_t) (idx->ne[2])
    };

    // Builds the bind group entry covering a tensor's aligned range in its buffer.
    auto bind_tensor = [&ctx](uint32_t slot, ggml_tensor * t) {
        return wgpu::BindGroupEntry{ .binding = slot,
                                     .buffer  = ggml_webgpu_tensor_buf(t),
                                     .offset  = ggml_webgpu_tensor_align_offset(ctx, t),
                                     .size    = ggml_webgpu_tensor_binding_size(ctx, t) };
    };

    std::vector<wgpu::BindGroupEntry> entries = { bind_tensor(0, src), bind_tensor(1, idx), bind_tensor(2, dst) };

    // Workgroup count: ceil(ne[1] * ne[2] * ne[3] of dst / maxComputeWorkgroupSizeX).
    size_t   wg_limit = ctx->limits.maxComputeWorkgroupSizeX;
    uint32_t wg_x     = (dst->ne[1] * dst->ne[2] * dst->ne[3] + wg_limit - 1) / wg_limit;

    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->get_rows_pipeline[src->type], params, entries, wg_x,
                                          ggml_op_name(dst->op));
}
|
||||
|
||||
static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
|
|
@ -711,43 +751,34 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
|
|||
case GGML_OP_RESHAPE:
|
||||
return false;
|
||||
case GGML_OP_CPY:
|
||||
{
|
||||
ggml_webgpu_cpy(ctx, src0, node);
|
||||
break;
|
||||
}
|
||||
ggml_webgpu_cpy(ctx, src0, node);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
ggml_webgpu_set_rows(ctx, src0, src1, node);
|
||||
break;
|
||||
}
|
||||
ggml_webgpu_set_rows(ctx, src0, src1, node);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_webgpu_get_rows(ctx, src0, src1, node);
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
||||
break;
|
||||
}
|
||||
ggml_webgpu_mul_mat(ctx, src0, src1, node);
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
|
||||
} else {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
|
||||
}
|
||||
break;
|
||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
|
||||
} else {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
|
||||
} else {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
|
||||
}
|
||||
break;
|
||||
if (ggml_webgpu_tensor_equal(src0, node)) {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
|
||||
} else {
|
||||
ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
ggml_webgpu_rms_norm(ctx, src0, node);
|
||||
break;
|
||||
}
|
||||
ggml_webgpu_rms_norm(ctx, src0, node);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -1079,8 +1110,56 @@ static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
|||
max_wg_size_entry(webgpu_ctx));
|
||||
}
|
||||
|
||||
// Creates one get_rows compute pipeline per supported source type and stores it
// in webgpu_ctx->get_rows_pipeline, indexed by the ggml_type value.
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
    // Registers the pipeline for a single type; every call shares the device,
    // the destination array and a freshly built max_wg_size_entry().
    auto register_variant = [&webgpu_ctx](ggml_type type, const auto & shader, const char * label) {
        ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[type], shader, label,
                                    max_wg_size_entry(webgpu_ctx));
    };

    // Non-quantized types
    register_variant(GGML_TYPE_F32, wgsl_get_rows_f32, "get_rows_f32");
    register_variant(GGML_TYPE_F16, wgsl_get_rows_f16, "get_rows_f16");
    register_variant(GGML_TYPE_I32, wgsl_get_rows_i32, "get_rows_i32");

    // Q4/Q5/Q8 types
    register_variant(GGML_TYPE_Q4_0, wgsl_get_rows_q4_0, "get_rows_q4_0");
    register_variant(GGML_TYPE_Q4_1, wgsl_get_rows_q4_1, "get_rows_q4_1");
    register_variant(GGML_TYPE_Q5_0, wgsl_get_rows_q5_0, "get_rows_q5_0");
    register_variant(GGML_TYPE_Q5_1, wgsl_get_rows_q5_1, "get_rows_q5_1");
    register_variant(GGML_TYPE_Q8_0, wgsl_get_rows_q8_0, "get_rows_q8_0");

    // *_K types
    register_variant(GGML_TYPE_Q2_K, wgsl_get_rows_q2_k, "get_rows_q2_k");
    register_variant(GGML_TYPE_Q3_K, wgsl_get_rows_q3_k, "get_rows_q3_k");
    register_variant(GGML_TYPE_Q4_K, wgsl_get_rows_q4_k, "get_rows_q4_k");
    register_variant(GGML_TYPE_Q5_K, wgsl_get_rows_q5_k, "get_rows_q5_k");
    register_variant(GGML_TYPE_Q6_K, wgsl_get_rows_q6_k, "get_rows_q6_k");

    // IQ* types
    register_variant(GGML_TYPE_IQ2_XXS, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs");
    register_variant(GGML_TYPE_IQ2_XS, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs");
    register_variant(GGML_TYPE_IQ2_S, wgsl_get_rows_iq2_s, "get_rows_iq2_s");
    register_variant(GGML_TYPE_IQ3_XXS, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs");
    register_variant(GGML_TYPE_IQ3_S, wgsl_get_rows_iq3_s, "get_rows_iq3_s");
    register_variant(GGML_TYPE_IQ1_S, wgsl_get_rows_iq1_s, "get_rows_iq1_s");
    register_variant(GGML_TYPE_IQ1_M, wgsl_get_rows_iq1_m, "get_rows_iq1_m");
    register_variant(GGML_TYPE_IQ4_NL, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl");
    register_variant(GGML_TYPE_IQ4_XS, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs");
}
|
||||
|
||||
// Creates the "cpy" compute pipeline from wgsl_cpy and stores it in
// webgpu_ctx->cpy_pipeline. The pipeline is created exactly once; a second,
// identical creation call would only rebuild and overwrite the same pipeline.
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
                                max_wg_size_entry(webgpu_ctx));
}
|
||||
|
||||
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
|
||||
|
|
@ -1162,6 +1241,33 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
|
|||
return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
|
||||
}
|
||||
|
||||
// Reports whether `type` is one of the quantized ggml types enumerated below.
// Any type not listed (including non-quantized ones) yields false.
static bool ggml_webgpu_supported_qtype(ggml_type type) {
    bool supported = false;

    switch (type) {
        // Q4/Q5/Q8 types
        case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0:
        // *_K types
        case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K:
        // IQ* types
        case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS:
            supported = true;
            break;
        default:
            break;
    }

    return supported;
}
|
||||
|
||||
static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
GGML_UNUSED(dev);
|
||||
|
||||
|
|
@ -1183,6 +1289,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
case GGML_OP_SET_ROWS:
|
||||
supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
|
||||
op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
|
||||
supports_op = (op->type == GGML_TYPE_F32);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
switch (op->src[1]->type) {
|
||||
|
|
@ -1220,9 +1332,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -1344,6 +1458,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
|
|||
ggml_webgpu_init_memset_pipeline(ctx);
|
||||
ggml_webgpu_init_mul_mat_pipeline(ctx);
|
||||
ggml_webgpu_init_set_rows_pipeline(ctx);
|
||||
ggml_webgpu_init_get_rows_pipeline(ctx);
|
||||
ggml_webgpu_init_cpy_pipeline(ctx);
|
||||
ggml_webgpu_init_add_pipeline(ctx);
|
||||
ggml_webgpu_init_mul_pipeline(ctx);
|
||||
|
|
|
|||
|
|
@ -70,8 +70,11 @@ def generate_variants(fname, input_dir, output_dir, outfile):
|
|||
except ValueError:
|
||||
decls_map = {}
|
||||
|
||||
with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f:
|
||||
common_decls = f.read()
|
||||
decls_map.update(parse_decls(common_decls))
|
||||
|
||||
shader_template = extract_block(text, "SHADER")
|
||||
shader_template = expand_includes(shader_template, input_dir)
|
||||
for variant in variants:
|
||||
if "DECLS" in variant:
|
||||
decls = variant["DECLS"]
|
||||
|
|
@ -85,6 +88,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
|
|||
|
||||
shader_variant = replace_placeholders(shader_template, variant["REPLS"])
|
||||
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
|
||||
final_shader = expand_includes(final_shader, input_dir)
|
||||
|
||||
if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
|
||||
output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue