ggml webgpu: support for rope,div,sub,glu,scale,cont operators (#16187)
* Work on rope
* Simplify inplace operation generation and combine mul/add generation
* Work on rope variants
* implement neox rope
* rope complete
* Add sub,div,glu operators
* implement scale op
* Update cpy shader to handle cont/more types
* formatting
* Update test vars printing for rope,rms_norm
* Avoid ROPE hardcoded constants
* Add TODO to change ROPE constants to enum
* fix TODO comment

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent d1c84a662d · commit 8d78cd2613
@@ -237,6 +237,8 @@
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 
+// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726
+#define GGML_ROPE_TYPE_NORMAL 0
 #define GGML_ROPE_TYPE_NEOX   2
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
 
@@ -130,13 +130,15 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline get_rows_pipeline[30];
     wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
-    wgpu::ComputePipeline cpy_pipeline;
-    wgpu::ComputePipeline add_pipeline[2];
-    wgpu::ComputePipeline add_ip_pipeline[2];
-    wgpu::ComputePipeline mul_pipeline[2];
-    wgpu::ComputePipeline mul_ip_pipeline[2];
-    wgpu::ComputePipeline rms_norm_pipeline;
-    wgpu::ComputePipeline rms_norm_ip_pipeline;
+    wgpu::ComputePipeline cpy_pipeline[2][2];      // src type, dst type
+    wgpu::ComputePipeline add_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline sub_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline mul_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline div_pipeline[2][2];      // type, inplace
+    wgpu::ComputePipeline rms_norm_pipeline[2];    // inplace
+    wgpu::ComputePipeline rope_pipeline[2][2][2];  // type, ff, inplace
+    wgpu::ComputePipeline glu_pipeline[7][2][2];   // glu-op, type, split
+    wgpu::ComputePipeline scale_pipeline[2];       // inplace
 
     size_t memset_bytes_per_thread;
 
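A quick sketch of the indexing convention these arrays rely on (a toy model, not code from the patch): in ggml.h, GGML_TYPE_F32 == 0 and GGML_TYPE_F16 == 1, which is why a [2] type axis suffices, and the inplace/split flags are plain 0/1 indices.

#include <cassert>

// Toy stand-in for the pipeline cache above; names here are illustrative.
enum toy_type { TOY_F32 = 0, TOY_F16 = 1 }; // mirrors GGML_TYPE_F32/F16 values

struct toy_ctx {
    int add_pipeline[2][2]; // [type][inplace], stand-in for wgpu::ComputePipeline
};

int pick_add(const toy_ctx & ctx, toy_type t, bool inplace) {
    return ctx.add_pipeline[t][inplace ? 1 : 0];
}

int main() {
    toy_ctx ctx = { { { 10, 11 }, { 20, 21 } } };
    assert(pick_add(ctx, TOY_F16, true) == 21); // f16, inplace variant
}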
@@ -489,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
         (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
         (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
         (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // Logical shape — same for both tensors even if permuted
-        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
+        // Logical shapes
+        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
@@ -506,7 +509,8 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
     size_t   max_wg_size = ctx->max_wg_size_x;
     uint32_t wg_x        = (ne + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
 }
 
 static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
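The dispatch width computed above is a ceiling division: enough workgroups of max_wg_size threads to cover ne elements, with the shader bounds-checking the tail. A minimal check of the idiom:

#include <cassert>
#include <cstdint>

// (ne + wg - 1) / wg rounds up, so partially-filled trailing workgroups
// are still dispatched; out-of-range threads return early in the shader.
static uint32_t ceil_div(uint32_t ne, uint32_t wg) { return (ne + wg - 1) / wg; }

int main() {
    assert(ceil_div(1024, 256) == 4); // exact fit
    assert(ceil_div(1025, 256) == 5); // one extra workgroup for the tail
}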
@@ -649,7 +653,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
                                   ggml_tensor *  src1,
                                   ggml_tensor *  dst,
                                   wgpu::ComputePipeline & pipeline,
-                                  bool in_place) {
+                                  bool inplace) {
     std::vector<uint32_t> params = {
         (uint32_t) ggml_nelements(dst),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -678,7 +682,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
           .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
           .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
     };
-    if (!in_place) {
+    if (!inplace) {
         entries.push_back({ .binding = 2,
                             .buffer = ggml_webgpu_tensor_buf(dst),
                             .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
@@ -691,30 +695,23 @@
 }
 
 static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool in_place = ggml_webgpu_tensor_equal(src, dst);
-
-    uint32_t eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
+    int inplace = ggml_webgpu_tensor_equal(src, dst);
 
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) src->ne[0],
+        (uint32_t) src->ne[1],
+        (uint32_t) src->ne[2],
+        (uint32_t) src->ne[3],
+        *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
     };
-    if (!in_place) {
-        params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
-    }
-    params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type)));
-    params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type)));
-    params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type)));
-    if (!in_place) {
-        params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
-        params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
-        params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
-    }
-    params.push_back((uint32_t) src->ne[0]);
-    params.push_back((uint32_t) src->ne[1]);
-    params.push_back((uint32_t) src->ne[2]);
-    params.push_back((uint32_t) src->ne[3]);
-    params.push_back(eps); // epsilon, will be bitcast to float in shader
 
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
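All uniform parameters travel in one u32 vector, so float values such as epsilon are passed as their raw bits on the host and reinterpreted as f32 in the shader. A round-trip sketch (the diff uses a pointer cast, `*(uint32_t *) dst->op_params`; memcpy below is the strictly well-defined equivalent):

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
    float eps = 1e-6f;
    uint32_t bits;
    std::memcpy(&bits, &eps, sizeof(bits)); // host side: f32 -> u32 bit pattern
    float back;
    std::memcpy(&back, &bits, sizeof(back)); // shader side, i.e. WGSL bitcast<f32>(bits)
    assert(back == eps); // bit-exact round trip
}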
@@ -722,24 +719,199 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
           .offset = ggml_webgpu_tensor_align_offset(ctx, src),
           .size = ggml_webgpu_tensor_binding_size(ctx, src) }
     };
-    if (!in_place) {
+    if (!inplace) {
         entries.push_back({ .binding = 1,
                             .buffer = ggml_webgpu_tensor_buf(dst),
                             .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
                             .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    wgpu::ComputePipeline pipeline;
-    if (in_place) {
-        pipeline = ctx->rms_norm_ip_pipeline;
-    } else {
-        pipeline = ctx->rms_norm_pipeline;
-    }
     size_t   max_wg_size = ctx->max_wg_size_x;
     uint32_t wg_x        = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
+}
+
+static void ggml_webgpu_rope(webgpu_context & ctx,
+                             ggml_tensor *  src0,
+                             ggml_tensor *  src1,
+                             ggml_tensor *  src2,
+                             ggml_tensor *  dst) {
+    const int inplace         = ggml_webgpu_tensor_equal(src0, dst);
+    const int has_freq_factor = (src2 != nullptr);
+
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5,  sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6,  sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7,  sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8,  sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9,  sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    int sections[4];
+    memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
+
+    float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(src0) / 2,
+        (uint32_t) src0->ne[0],
+        (uint32_t) src0->ne[1],
+        (uint32_t) src0->ne[2],
+        (uint32_t) n_dims,
+        (uint32_t) mode,
+        *(uint32_t *) &theta_scale,
+        *(uint32_t *) &attn_factor,
+        *(uint32_t *) &freq_scale,
+        *(uint32_t *) &ext_factor,
+        *(uint32_t *) &corr_dims[0],
+        *(uint32_t *) &corr_dims[1],
+        (uint32_t) sections[0],
+        (uint32_t) sections[1],
+        (uint32_t) sections[2],
+        (uint32_t) sections[3]
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer = ggml_webgpu_tensor_buf(src0),
+          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+        { .binding = 1,
+          .buffer = ggml_webgpu_tensor_buf(src1),
+          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+          .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
+    };
+    uint32_t dst_binding = 2;
+    if (has_freq_factor) {
+        dst_binding = 3;
+        entries.push_back({ .binding = 2,
+                            .buffer = ggml_webgpu_tensor_buf(src2),
+                            .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
+                            .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
+    }
+    if (!inplace) {
+        entries.push_back({ .binding = dst_binding,
+                            .buffer = ggml_webgpu_tensor_buf(dst),
+                            .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace];
+    size_t   max_wg_size = ctx->max_wg_size_x;
+    uint32_t wg_x        = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+}
+
+static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+    const int split = (src1 != nullptr);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
+                          (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
+                          (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
+                          (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(dst),
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2],
+        (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
+        *(uint32_t *) &dst->op_params[2],           // alpha, for swiglu_oai
+        *(uint32_t *) &dst->op_params[3],           // limit, for swiglu_oai
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer = ggml_webgpu_tensor_buf(src0),
+          .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+    };
+    uint32_t dst_binding = 1;
+    if (split) {
+        dst_binding = 2;
+        entries.push_back({ .binding = 1,
+                            .buffer = ggml_webgpu_tensor_buf(src1),
+                            .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+                            .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
+    }
+    entries.push_back({ .binding = dst_binding,
+                        .buffer = ggml_webgpu_tensor_buf(dst),
+                        .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+                        .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+    wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split];
+    size_t   max_wg_size = ctx->max_wg_size_x;
+    uint32_t wg_x        = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+}
+
+static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    int inplace = ggml_webgpu_tensor_equal(src, dst);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+        (uint32_t) ggml_nelements(dst),
+        (uint32_t) src->ne[0],
+        (uint32_t) src->ne[1],
+        (uint32_t) src->ne[2],
+        *(uint32_t *) dst->op_params,    // scale
+        *(uint32_t *) &dst->op_params[1] // bias
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer = ggml_webgpu_tensor_buf(src),
+          .offset = ggml_webgpu_tensor_align_offset(ctx, src),
+          .size = ggml_webgpu_tensor_binding_size(ctx, src) }
+    };
+    if (!inplace) {
+        entries.push_back({ .binding = 1,
+                            .buffer = ggml_webgpu_tensor_buf(dst),
+                            .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    size_t   max_wg_size = ctx->max_wg_size_x;
+    uint32_t wg_x        = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
+}
 
 // Returns true if node has enqueued work into the queue, false otherwise
 static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
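For orientation, the rotation the rope shader ultimately applies: with theta_scale = freq_base^(-2/n_dims) as computed above, dimension pair i is rotated by an angle pos * freq_scale * theta_scale^i. A CPU sketch of the plain "normal" mode for one row (NEOX pairs (x[i], x[i + n_dims/2]) instead, and the YaRN ext_factor/corr_dims adjustment is omitted here):

#include <cmath>
#include <vector>

// Reference rotation for GGML_ROPE_TYPE_NORMAL, no YaRN; a sketch, not the shader.
void rope_row_ref(std::vector<float> & x, int n_dims, int pos, float freq_base, float freq_scale) {
    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
    float theta = (float) pos * freq_scale; // angle for the first pair
    for (int i = 0; i < n_dims; i += 2) {
        const float c = std::cos(theta), s = std::sin(theta);
        const float x0 = x[i], x1 = x[i + 1]; // adjacent pairing ("normal" mode)
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
        theta *= theta_scale; // each subsequent pair rotates more slowly
    }
}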
@@ -749,6 +921,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
 
     ggml_tensor * src0 = node->src[0];
     ggml_tensor * src1 = node->src[1];
+    ggml_tensor * src2 = node->src[2];
 
     switch (node->op) {
         // no-ops
@@ -759,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
         case GGML_OP_RESHAPE:
            return false;
         case GGML_OP_CPY:
+        case GGML_OP_CONT:
            ggml_webgpu_cpy(ctx, src0, node);
            break;
         case GGML_OP_SET_ROWS:
@@ -771,22 +945,41 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
             ggml_webgpu_mul_mat(ctx, src0, src1, node);
             break;
         case GGML_OP_ADD:
-            if (ggml_webgpu_tensor_equal(src0, node)) {
-                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
-            } else {
-                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
-            }
-            break;
+            {
+                int inplace = ggml_webgpu_tensor_equal(src0, node);
+                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace);
+                break;
+            }
+        case GGML_OP_SUB:
+            {
+                int inplace = ggml_webgpu_tensor_equal(src0, node);
+                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace);
+                break;
+            }
         case GGML_OP_MUL:
-            if (ggml_webgpu_tensor_equal(src0, node)) {
-                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
-            } else {
-                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
-            }
-            break;
+            {
+                int inplace = ggml_webgpu_tensor_equal(src0, node);
+                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace);
+                break;
+            }
+        case GGML_OP_DIV:
+            {
+                int inplace = ggml_webgpu_tensor_equal(src0, node);
+                ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace);
+                break;
+            }
         case GGML_OP_RMS_NORM:
             ggml_webgpu_rms_norm(ctx, src0, node);
             break;
+        case GGML_OP_ROPE:
+            ggml_webgpu_rope(ctx, src0, src1, src2, node);
+            break;
+        case GGML_OP_GLU:
+            ggml_webgpu_glu(ctx, src0, src1, node);
+            break;
+        case GGML_OP_SCALE:
+            ggml_webgpu_scale(ctx, src0, node);
+            break;
         default:
             return false;
     }
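A side note on the braced case bodies above: each one declares an initialized `int inplace`, and C++ forbids jumping past such a declaration to a later case label, so every case gets its own scope. Minimal illustration:

// Why the new case bodies are braced: a declaration with an initializer may
// not be skipped by a jump to a later case label within the same scope.
void dispatch(int op) {
    switch (op) {
        case 0:
            {
                int inplace = 1; // scoped to this case only
                (void) inplace;
                break;
            }
        case 1:
            // without the braces above, the declaration in case 0 would make
            // this label ill-formed ("jump to case label crosses initialization")
            break;
    }
}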
@@ -1170,40 +1363,153 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
-                                ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
+                                wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
+                                wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
+                                wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
+                                wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
-                                "add_in_place_f32", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
-                                "add_in_place_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace,
+                                "add_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace,
+                                "add_f16_inplace", constants);
+}
+
+static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace,
+                                "sub_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace,
+                                "sub_f16_inplace", constants);
 }
 
 static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
-                                "mul_in_place_f32", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
-                                "mul_in_place_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace,
+                                "mul_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace,
+                                "mul_f16_inplace", constants);
+}
+
+static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace,
+                                "div_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace,
+                                "div_f16_inplace", constants);
 }
 
 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place,
-                                "rms_norm_in_place", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
+                                "rms_norm_inplace", constants);
+}
+
+static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
+                                "rope_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
+                                wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff,
+                                "rope_f32_ff", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1],
+                                wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16,
+                                "rope_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1],
+                                wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff,
+                                "rope_f16_ff", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1],
+                                wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
+}
+
+static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    // reglu
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
+                                wgsl_reglu_f32, "reglu_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0],
+                                wgsl_reglu_f16, "reglu_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1],
+                                wgsl_reglu_f32_split, "reglu_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1],
+                                wgsl_reglu_f16_split, "reglu_f16_split", constants);
+    // geglu
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0],
+                                wgsl_geglu_f32, "geglu_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0],
+                                wgsl_geglu_f16, "geglu_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1],
+                                wgsl_geglu_f32_split, "geglu_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1],
+                                wgsl_geglu_f16_split, "geglu_f16_split", constants);
+    // swiglu
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0],
+                                wgsl_swiglu_f32, "swiglu_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0],
+                                wgsl_swiglu_f16, "swiglu_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1],
+                                wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1],
+                                wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
+    // swiglu_oai
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0],
+                                wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1],
+                                wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
+    // geglu_erf
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0],
+                                wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0],
+                                wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1],
+                                wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1],
+                                wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
+    // geglu_quick
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0],
+                                wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0],
+                                wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1],
+                                wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1],
+                                wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
+}
+
+static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
+                                "scale_f32_inplace", constants);
 }
 
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
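The `ggml_webgpu_create_pipeline` helper itself is not part of this diff; presumably the `constants` vector it receives fills WGSL `override` values such as `wg_size` at pipeline-creation time. A sketch of that mechanism with the webgpu_cpp (Dawn) API, under the assumption that this is roughly what the helper does (member names follow webgpu_cpp and may vary across Dawn versions; shader-module creation is elided):

#include <webgpu/webgpu_cpp.h>
#include <vector>

// Hypothetical pipeline-creation helper: override constants (e.g. a
// wgpu::ConstantEntry with key "wg_size" and a numeric value) are attached
// to the compute stage of the pipeline descriptor.
wgpu::ComputePipeline make_pipeline(wgpu::Device device, wgpu::ShaderModule module,
                                    const char * entry_point,
                                    const std::vector<wgpu::ConstantEntry> & constants) {
    wgpu::ComputePipelineDescriptor desc = {};
    desc.compute.module        = module;
    desc.compute.entryPoint    = entry_point;
    desc.compute.constants     = constants.data();
    desc.compute.constantCount = constants.size();
    return device.CreateComputePipeline(&desc);
}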
@@ -1287,6 +1593,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
     ggml_tensor * src0 = op->src[0];
     ggml_tensor * src1 = op->src[1];
 
     // on smaller devices (or CI), tensors may be larger than the max storage buffer size
     if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
         (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
@@ -1304,28 +1611,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             supports_op = true;
             break;
         case GGML_OP_ADD:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
-            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
-                          (op->src[1]->type == op->type);
+        case GGML_OP_DIV:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
+                          (src1->type == op->type);
             break;
         case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                          (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+            break;
         case GGML_OP_SET_ROWS:
             supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
             break;
         case GGML_OP_GET_ROWS:
-            if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
-                op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) {
+            if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
+                ggml_webgpu_supported_qtype(src0->type)) {
                 supports_op = (op->type == GGML_TYPE_F32);
             }
             break;
         case GGML_OP_MUL_MAT:
             {
-                switch (op->src[1]->type) {
+                switch (src1->type) {
                     case GGML_TYPE_F16:
-                        supports_op = (op->src[0]->type == GGML_TYPE_F16);
+                        supports_op |= (src0->type == GGML_TYPE_F16);
                         break;
                     case GGML_TYPE_F32:
-                        switch (op->src[0]->type) {
+                        switch (src0->type) {
                             case GGML_TYPE_F32:
                             case GGML_TYPE_F16:
                             case GGML_TYPE_Q4_0:
@@ -1358,7 +1671,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                 break;
             }
         case GGML_OP_RMS_NORM:
-            supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
+            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+            break;
+        case GGML_OP_ROPE:
+            supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
+                    break;
+                case GGML_GLU_OP_SWIGLU_OAI:
+                    supports_op = op->type == GGML_TYPE_F32;
+                    break;
+                default:
+                    break;
+            }
+            break;
+        case GGML_OP_SCALE:
+            supports_op = op->type == GGML_TYPE_F32;
             break;
         default:
             break;
@@ -1484,8 +1819,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ggml_webgpu_init_get_rows_pipeline(ctx);
     ggml_webgpu_init_cpy_pipeline(ctx);
     ggml_webgpu_init_add_pipeline(ctx);
+    ggml_webgpu_init_sub_pipeline(ctx);
     ggml_webgpu_init_mul_pipeline(ctx);
+    ggml_webgpu_init_div_pipeline(ctx);
     ggml_webgpu_init_rms_norm_pipeline(ctx);
+    ggml_webgpu_init_rope_pipeline(ctx);
+    ggml_webgpu_init_glu_pipeline(ctx);
+    ggml_webgpu_init_scale_pipeline(ctx);
 
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
@@ -1,44 +0,0 @@ (deleted: standalone add shader template)
-#define(VARIANTS)
-
-[
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-    }
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-    }
-  }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x < params.ne) {
-        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
-    }
-}
-
-#end(SHADER)
@@ -1,41 +0,0 @@ (deleted: standalone in-place add shader template)
-#define(VARIANTS)
-
-[
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-    }
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-    }
-  }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x < params.ne) {
-        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
-    }
-}
-
-#end(SHADER)
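The two deleted standalone add shaders above are subsumed by the templated binary-op shader added next: one SHADER body shared by add, sub, mul, and div, with the operator, element type, and the in-place/not-in-place destination declarations filled in per variant.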
@@ -0,0 +1,188 @@ (new: unified binary-op shader template)
+#define(VARIANTS)
+
+[
+  {
+    "SHADER_NAME": "add_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "+"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "add_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "+"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "add_f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "+"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "add_f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "+"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "mul_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "*"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "mul_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "*"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "mul_f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "*"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "mul_f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "*"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "sub_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "-"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "sub_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "-"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "sub_f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "-"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "sub_f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "-"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "div_f32",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "/"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "div_f16",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "/"
+    },
+    "DECLS": ["NOT_INPLACE"]
+  },
+  {
+    "SHADER_NAME": "div_f32_inplace",
+    "REPLS": {
+      "TYPE" : "f32",
+      "OP": "/"
+    },
+    "DECLS": ["INPLACE"]
+  },
+  {
+    "SHADER_NAME": "div_f16_inplace",
+    "REPLS": {
+      "TYPE" : "f16",
+      "OP": "/"
+    },
+    "DECLS": ["INPLACE"]
+  }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+
+fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
+    dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
+}
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+
+fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
+    src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
+}
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+enable f16;
+
+#include "binary_head.tmpl"
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<{{TYPE}}>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
+    }
+}
+
+#end(SHADER)
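The src1_index() helper used above comes from binary_head.tmpl, which is not part of this diff; presumably it re-maps the flat destination index onto src1 with a per-dimension wrap to implement ggml-style broadcasting. A CPU sketch of that presumed mapping:

#include <cstdint>

// Sketch of the presumed src1_index() broadcast mapping (binary_head.tmpl is
// not shown in this diff): decompose the flat dst index over dst's shape,
// wrap each coordinate by src1's extent (ggml-style repeat), then apply
// src1's strides.
static uint32_t src1_index_ref(uint32_t flat,
                               const uint32_t dst_ne[4],
                               const uint32_t src1_ne[4],
                               const uint32_t src1_stride[4]) {
    uint32_t idx = 0;
    for (int d = 0; d < 4; ++d) {
        const uint32_t coord = flat % dst_ne[d]; // coordinate along dimension d
        flat /= dst_ne[d];
        idx += (coord % src1_ne[d]) * src1_stride[d]; // wrap = broadcast
    }
    return idx;
}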
@@ -0,0 +1,101 @@ (new: cpy shader template)
+#define(VARIANTS)
+
+[
+  {
+    "REPLS": {
+      "SRC_TYPE": "f32",
+      "DST_TYPE": "f32"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f32",
+      "DST_TYPE": "f16"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f16",
+      "DST_TYPE": "f16"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f16",
+      "DST_TYPE": "f32"
+    }
+  }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+enable f16;
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<{{SRC_TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<{{DST_TYPE}}>;
+
+struct Params {
+    ne: u32,         // total number of elements
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements) — may be permuted
+    stride_src0: u32,
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst0: u32,
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Logical shapes
+    src_ne0: u32,
+    src_ne1: u32,
+    src_ne2: u32,
+
+    dst_ne0: u32,
+    dst_ne1: u32,
+    dst_ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
+    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
+    let i2 = i / (params.src_ne1 * params.src_ne0);
+    i = i % (params.src_ne1 * params.src_ne0);
+    let i1 = i / params.src_ne0;
+    let i0 = i % params.src_ne0;
+
+    var j = gid.x;
+    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+    let j2 = j / (params.dst_ne1 * params.dst_ne0);
+    j = j % (params.dst_ne1 * params.dst_ne0);
+    let j1 = j / params.dst_ne0;
+    let j0 = j % params.dst_ne0;
+
+    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+                  i2 * params.stride_src2 + i3 * params.stride_src3;
+
+    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
+                  j2 * params.stride_dst2 + j3 * params.stride_dst3;
+
+    dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
+}
+#end(SHADER)
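The key change over the old shader below: the same flat thread id is decomposed against the source and the destination shapes independently, so the two tensors may be permuted differently, which is what makes GGML_OP_CONT expressible as a copy. A direct CPU transcription of that index math, as a sketch:

#include <cstdint>

// CPU reference mirroring the new cpy shader's index math (f32->f32 case).
struct cpy_params {
    uint32_t src_ne[3];     // src_ne0..2 (the outermost dim is implied by ne)
    uint32_t dst_ne[3];
    uint32_t src_stride[4]; // strides in elements, possibly permuted
    uint32_t dst_stride[4];
};

static void cpy_ref(const float * src, float * dst, uint32_t ne, const cpy_params & p) {
    for (uint32_t g = 0; g < ne; ++g) {
        uint32_t i = g; // decompose against the source shape
        const uint32_t i3 = i / (p.src_ne[2] * p.src_ne[1] * p.src_ne[0]);
        i %= p.src_ne[2] * p.src_ne[1] * p.src_ne[0];
        const uint32_t i2 = i / (p.src_ne[1] * p.src_ne[0]);
        i %= p.src_ne[1] * p.src_ne[0];
        const uint32_t i1 = i / p.src_ne[0], i0 = i % p.src_ne[0];

        uint32_t j = g; // decompose the same id against the destination shape
        const uint32_t j3 = j / (p.dst_ne[2] * p.dst_ne[1] * p.dst_ne[0]);
        j %= p.dst_ne[2] * p.dst_ne[1] * p.dst_ne[0];
        const uint32_t j2 = j / (p.dst_ne[1] * p.dst_ne[0]);
        j %= p.dst_ne[1] * p.dst_ne[0];
        const uint32_t j1 = j / p.dst_ne[0], j0 = j % p.dst_ne[0];

        dst[j0 * p.dst_stride[0] + j1 * p.dst_stride[1] + j2 * p.dst_stride[2] + j3 * p.dst_stride[3]] =
            src[i0 * p.src_stride[0] + i1 * p.src_stride[1] + i2 * p.src_stride[2] + i3 * p.src_stride[3]];
    }
}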
@@ -1,60 +0,0 @@ (deleted: old single-variant cpy shader)
-enable f16;
-
-@group(0) @binding(0)
-var<storage, read_write> src: array<f32>;
-
-@group(0) @binding(1)
-var<storage, read_write> dst: array<f16>;
-
-struct Params {
-    ne: u32,         // total number of elements
-    offset_src: u32, // in elements
-    offset_dst: u32, // in elements
-
-    // Strides (in elements) — may be permuted
-    stride_src0: u32,
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_dst0: u32,
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    // Logical shape (same for both tensors)
-    ne0: u32,
-    ne1: u32,
-    ne2: u32,
-    ne3: u32,
-};
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne) {
-        return;
-    }
-
-    var i = gid.x;
-
-    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
-    i = i % (params.ne2 * params.ne1 * params.ne0);
-
-    let i2 = i / (params.ne1 * params.ne0);
-    i = i % (params.ne1 * params.ne0);
-
-    let i1 = i / params.ne0;
-    let i0 = i % params.ne0;
-
-    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
-                  i2 * params.stride_src2 + i3 * params.stride_src3;
-
-    let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 +
-                  i2 * params.stride_dst2 + i3 * params.stride_dst3;
-
-    dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]);
-}
@@ -88,15 +88,20 @@ def generate_variants(fname, input_dir, output_dir, outfile):
                     raise ValueError(f"DECLS key '{key}' not found.")
                 decls_code += decls_map[key] + "\n\n"
 
-            shader_variant = replace_placeholders(shader_template, variant["REPLS"])
-            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
+            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
+            if "REPLS" in variant:
+                final_shader = replace_placeholders(final_shader, variant["REPLS"])
             final_shader = expand_includes(final_shader, input_dir)
 
-            if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
+            if "SHADER_NAME" in variant:
+                output_name = variant["SHADER_NAME"]
+            elif "SHADER_SUFFIX" in variant:
+                output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
+            elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
                 output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
-            elif "TYPE_SUFFIX" in variant["REPLS"]:
-                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"]
+            elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
+                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
-            elif "TYPE" in variant["REPLS"]:
+            elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
                 output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
             else:
                 output_name = shader_base_name
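With these rules, a variant that sets "SHADER_NAME": "add_f32_inplace" is emitted under exactly that name; the get_rows vec4 variant keeps its old f32_vec suffix by moving the suffix out of REPLS into the new top-level SHADER_SUFFIX key (see the next hunk); and the new cpy variants fall back to the <base>_<SRC_TYPE>_<DST_TYPE> form, e.g. cpy_f32_f16.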
@@ -2,9 +2,9 @@
 [
   {
+    "SHADER_SUFFIX": "f32_vec",
     "REPLS": {
       "TYPE" : "vec4<f32>",
-      "TYPE_SUFFIX": "f32_vec",
       "DST_TYPE": "vec4<f32>",
       "BLOCK_SIZE": 4
     },
@@ -0,0 +1,323 @@
#define(VARIANTS)

[
  {
    "SHADER_NAME": "reglu_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "REGLU"]
  },
  {
    "SHADER_NAME": "reglu_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "REGLU"]
  },
  {
    "SHADER_NAME": "reglu_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "REGLU"]
  },
  {
    "SHADER_NAME": "reglu_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "REGLU"]
  },
  {
    "SHADER_NAME": "geglu_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "GEGLU"]
  },
  {
    "SHADER_NAME": "geglu_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "GEGLU"]
  },
  {
    "SHADER_NAME": "geglu_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "GEGLU"]
  },
  {
    "SHADER_NAME": "geglu_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "GEGLU"]
  },
  {
    "SHADER_NAME": "swiglu_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "SWIGLU"]
  },
  {
    "SHADER_NAME": "swiglu_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "SWIGLU"]
  },
  {
    "SHADER_NAME": "swiglu_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "SWIGLU"]
  },
  {
    "SHADER_NAME": "swiglu_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "SWIGLU"]
  },
  {
    "SHADER_NAME": "swiglu_oai_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
  },
  {
    "SHADER_NAME": "swiglu_oai_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "SWIGLU_OAI"]
  },
  {
    "SHADER_NAME": "geglu_erf_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
  },
  {
    "SHADER_NAME": "geglu_erf_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "GEGLU_ERF"]
  },
  {
    "SHADER_NAME": "geglu_erf_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
  },
  {
    "SHADER_NAME": "geglu_erf_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "GEGLU_ERF"]
  },
  {
    "SHADER_NAME": "geglu_quick_f32",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
  },
  {
    "SHADER_NAME": "geglu_quick_f32_split",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["SPLIT", "GEGLU_QUICK"]
  },
  {
    "SHADER_NAME": "geglu_quick_f16",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
  },
  {
    "SHADER_NAME": "geglu_quick_f16_split",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["SPLIT", "GEGLU_QUICK"]
  },
]

#end(VARIANTS)
#define(DECLS)

#decl(REGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    return max(a, 0) * b;
}
#enddecl(REGLU)

#decl(GEGLU)
const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
const GELU_COEF_A: {{TYPE}} = 0.044715;

fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
    return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
}
#enddecl(GEGLU)

#decl(SWIGLU)
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    return a / (1.0 + exp(-a)) * b;
}
#enddecl(SWIGLU)

#decl(SWIGLU_OAI)
fn op(a: f32, b: f32) -> f32 {
    let xi = min(a, params.limit);
    let gi = max(min(b, params.limit), -params.limit);
    var out_glu = xi / (1.0 + exp(-xi * params.alpha));
    out_glu = out_glu * (1.0 + gi);
    return out_glu;
}
#enddecl(SWIGLU_OAI)

#decl(GEGLU_ERF)
const p_erf: {{TYPE}} = 0.3275911;
const a1_erf: {{TYPE}} = 0.254829592;
const a2_erf: {{TYPE}} = -0.284496736;
const a3_erf: {{TYPE}} = 1.421413741;
const a4_erf: {{TYPE}} = -1.453152027;
const a5_erf: {{TYPE}} = 1.061405429;
const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;

fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let a_div_sqr2 = a * SQRT_2_INV;
    let sign_x = sign(a_div_sqr2);
    let x = abs(a_div_sqr2);
    let t = 1.0 / (1.0 + p_erf * x);
    let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
    let erf_approx = sign_x * y;
    return 0.5 * a * (1.0 + erf_approx) * b;
}
#enddecl(GEGLU_ERF)

#decl(GEGLU_QUICK)
const GELU_QUICK_COEF: {{TYPE}} = -1.702;

fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#enddecl(GEGLU_QUICK)
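As a sanity check on the formulas above: the GEGLU body avoids calling tanh by using tanh(v) = 1 - 2/(exp(2v) + 1), so 0.5*a*(2.0 - 2.0/(exp(2*val) + 1)) is exactly the familiar 0.5*a*(1 + tanh(val)). A NumPy sketch of the three most common gates (illustrative only):

    import numpy as np

    def reglu(a, b):
        return np.maximum(a, 0) * b

    def swiglu(a, b):
        return a / (1 + np.exp(-a)) * b      # a * sigmoid(a), then gated by b

    def geglu(a, b):
        # tanh approximation of GELU, same constants as the shader
        v = 0.79788456080286535587989211986876 * a * (1 + 0.044715 * a * a)
        return 0.5 * a * (1 + np.tanh(v)) * b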
#decl(NO_SPLIT)
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;

fn a_value(base: u32) -> {{TYPE}} {
    let offset: u32 = select(0, params.ne0, params.swapped != 0);
    return src0[base + offset];
}

fn b_value(base: u32) -> {{TYPE}} {
    let offset: u32 = select(params.ne0, 0, params.swapped != 0);
    return src0[base + offset];
}
#enddecl(NO_SPLIT)

#decl(SPLIT)
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(3)
var<uniform> params: Params;

fn a_value(base: u32) -> {{TYPE}} {
    return src0[base];
}

fn b_value(base: u32) -> {{TYPE}} {
    return src1[base];
}
#enddecl(SPLIT)

#end(DECLS)
#define(SHADER)

enable f16;

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // shape of dst
    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    swapped: u32,
    alpha: f32,
    limit: f32,
}

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

DECLS

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }

    var i = gid.x;
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;

    dst[i_dst] = op(a_value(i_a), b_value(i_b));
}

#end(SHADER)
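The NO_SPLIT and SPLIT decls differ only in where the gate comes from: in the fused layout both halves of a row live in src0 and params.swapped picks which half is the activation, while the split layout reads the gate from a separate src1 buffer. A rough sketch of the fused addressing (assumed row layout, illustrative only):

    def fused_pair(row, i0, ne0, swapped):
        # row holds [a | b] when not swapped, [b | a] when swapped; each half is ne0 wide
        a = row[i0 + (ne0 if swapped else 0)]
        b = row[i0 + (0 if swapped else ne0)]
        return a, b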
@@ -1,44 +0,0 @@
#define(VARIANTS)

[
  {
    "REPLS": {
      "TYPE" : "f32",
    }
  },
  {
    "REPLS": {
      "TYPE" : "f16",
    }
  }
]

#end(VARIANTS)

#define(SHADER)

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(3)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
    }
}

#end(SHADER)
@@ -1,41 +0,0 @@
#define(VARIANTS)

[
  {
    "REPLS": {
      "TYPE" : "f32",
    }
  },
  {
    "REPLS": {
      "TYPE" : "f16",
    }
  }
]

#end(VARIANTS)

#define(SHADER)

enable f16;

#include "binary_head.tmpl"

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x < params.ne) {
        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
    }
}

#end(SHADER)
@@ -1,9 +1,48 @@
@group(0) @binding(0)
#define(VARIANTS)
var<storage, read_write> src: array<f32>;

[
  {
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_SUFFIX": "inplace",
    "DECLS": ["INPLACE"]
  },
]

#end(VARIANTS)

#define(DECLS)

#decl(NOT_INPLACE)

fn update(src_offset: u32, dst_offset: u32, scale: f32) {
    dst[dst_offset] = scale * src[src_offset];
}

@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(2)
var<uniform> params: Params;

#enddecl(NOT_INPLACE)

#decl(INPLACE)

fn update(src_offset: u32, dst_offset: u32, scale: f32) {
    src[dst_offset] = scale * src[src_offset];
}

@group(0) @binding(1)
var<uniform> params: Params;

#enddecl(INPLACE)

#end(DECLS)

#define(SHADER)

struct Params {
    offset_src: u32, // in elements
    offset_dst: u32, // in elements

@@ -23,11 +62,13 @@ struct Params {
    ne2: u32,
    ne3: u32,

    eps: u32
    eps: f32
};

@group(0) @binding(2)
@group(0) @binding(0)
var<uniform> params: Params;
var<storage, read_write> src: array<f32>;

DECLS

override wg_size: u32;
@compute @workgroup_size(wg_size)

@@ -49,9 +90,9 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    for (var j: u32 = 0; j < params.ne0; j++) {
        sum += src[i_src_row + j] * src[i_src_row + j];
    }
    let eps = bitcast<f32>(params.eps);
    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
    let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
    for (var j: u32 = 0; j < params.ne0; j++) {
        dst[i_dst_row + j] = scale * src[i_src_row + j];
        update(i_src_row + j, i_dst_row + j, scale);
    }
}

#end(SHADER)
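Note the cleanup in the last hunk: eps used to be passed as u32 and bitcast to f32 in the shader; it is now declared f32 in the uniform block and used directly. What one row computes, as a reference (illustrative NumPy, not the ggml API):

    import numpy as np

    def rms_norm_row(x, eps):
        scale = 1.0 / np.sqrt(np.mean(x * x) + eps)
        return scale * x   # the inplace variant writes this back into the src buffer via update()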
@@ -1,48 +0,0 @@
@group(0) @binding(0)
var<storage, read_write> a: array<f32>;

struct Params {
    offset: u32, // in elements

    // Strides (in elements)
    stride1: u32,
    stride2: u32,
    stride3: u32,

    // Shape
    ne0: u32,
    ne1: u32,
    ne2: u32,
    ne3: u32,

    eps: u32
};

@group(0) @binding(1)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
        return;
    }

    // one thread per row
    var i = gid.x;
    let i3 = i / (params.ne2 * params.ne1);
    i = i % (params.ne2 * params.ne1);
    let i2 = i / params.ne1;
    let i1 = i % params.ne1;
    let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1;

    var sum = 0.0f;
    for (var j: u32 = 0; j < params.ne0; j++) {
        sum += a[i_row + j] * a[i_row + j];
    }
    let eps = bitcast<f32>(params.eps);
    let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
    for (var j: u32 = 0; j < params.ne0; j++) {
        a[i_row + j] = scale * a[i_row + j];
    }
}
@@ -0,0 +1,282 @@
#define(VARIANTS)

[
  {
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
  },
  {
    "SHADER_SUFFIX": "f32_inplace",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
  },
  {
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
  },
  {
    "SHADER_SUFFIX": "f16_inplace",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
  },
  {
    "SHADER_SUFFIX": "f32_ff",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
  },
  {
    "SHADER_SUFFIX": "f32_ff_inplace",
    "REPLS": {
      "TYPE" : "f32",
    },
    "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
  },
  {
    "SHADER_SUFFIX": "f16_ff",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
  },
  {
    "SHADER_SUFFIX": "f16_ff_inplace",
    "REPLS": {
      "TYPE" : "f16",
    },
    "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
  }
]

#end(VARIANTS)
#define(DECLS)

#decl(ROTATE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
    dst[i_dst0] = {{TYPE}}(out0);
    dst[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE)

#decl(ROTATE_INPLACE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
    src0[i_dst0] = {{TYPE}}(out0);
    src0[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE_INPLACE)

#decl(NO_FF_FUNC)
fn freq_factor(i: u32) -> f32 {
    return 1.0f;
}
#enddecl(NO_FF_FUNC)

#decl(FF_FUNC)
fn freq_factor(i: u32) -> f32 {
    return src2[params.offset_src2 + i/2];
}
#enddecl(FF_FUNC)

#decl(NO_FF_BINDINGS)

@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(3)
var<uniform> params: Params;

#enddecl(NO_FF_BINDINGS)

#decl(NO_FF_BINDINGS_INPLACE)

@group(0) @binding(2)
var<uniform> params: Params;

#enddecl(NO_FF_BINDINGS_INPLACE)

#decl(FF_BINDINGS)

@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;

@group(0) @binding(3)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(4)
var<uniform> params: Params;

#enddecl(FF_BINDINGS)

#decl(FF_BINDINGS_INPLACE)

@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;

@group(0) @binding(3)
var<uniform> params: Params;

#enddecl(FF_BINDINGS_INPLACE)

#end(DECLS)
#define(SHADER)

enable f16;

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_src2: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    n_threads: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    n_dims: u32,
    mode: u32,
    theta_scale: f32,
    attn_factor: f32,
    freq_scale: f32,
    ext_factor: f32,
    corr_dim0: f32,
    corr_dim1: f32,
    sections0: u32,
    sections1: u32,
    sections2: u32,
    sections3: u32
};

@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

@group(0) @binding(1)
var<storage, read_write> src1: array<i32>;

DECLS

fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
    let y = (f32(i / 2) - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
}

// returns vector of (cos_theta, sin_theta)
// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
    var mscale = params.attn_factor;
    var theta = params.freq_scale * theta_extrap;
    if (params.ext_factor != 0.0f) {
        let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
        theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
        mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
    }
    return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
}

fn pair_base(i0: u32, div_2: bool) -> u32 {
    if (div_2) {
        return i0 / 2;
    } else {
        return i0;
    }
}

fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
    if (is_vision) {
        return params.n_dims;
    } else if (is_neox || is_mrope) {
        return params.n_dims / 2;
    } else {
        return 1;
    }
}

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // two elements per thread
    if (gid.x >= params.n_threads) {
        return;
    }

    let is_neox = bool(params.mode & 2);
    let is_mrope = bool(params.mode & 8);
    let is_vision = params.mode == 24;

    var i = gid.x * 2; // start index for this thread
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

    if (i0 >= params.n_dims && !is_vision) {
        let i_src = i_src_row + i0;
        let i_dst = i_dst_row + i0;
        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
        return;
    }

    var theta_base_mult: u32 = 0;
    var theta_scale_pwr: u32 = i0 / 2;
    if (is_mrope) {
        let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
        let sec_w = params.sections1 + params.sections0;
        let sec_e = params.sections2 + sec_w;
        let sector = (i0 / 2) % sect_dims;
        if (sector >= params.sections0 && sector < sec_w) {
            theta_base_mult = 1;
            if (is_vision) {
                theta_scale_pwr = sector - params.sections0;
            }
        } else if (sector >= sec_w && sector < sec_e) {
            theta_base_mult = 2;
            if (is_vision) {
                theta_scale_pwr = sector - sec_w;
            }
        } else if (sector >= sec_e) {
            if (is_vision) {
                theta_scale_pwr = sector - sec_e;
                theta_scale_pwr = (i0 / 2) % sec_e;
            }
            theta_base_mult = 3;
        } else if (is_vision) {
            theta_scale_pwr = sector;
        }
    }
    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
    let thetas = rope_yarn(theta_base/freq_factor(i0), i0);

    let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
    let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);

    let x0 = f32(src0[i_src]);
    let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
    rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
}

#end(SHADER)
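A compact Python model of the rotation math above; names are illustrative only. The mode bits checked in main() correspond to the GGML_ROPE_TYPE_* values (NEOX = 2, MROPE = 8, VISION = 24), and each thread rotates one pair of elements:

    import math

    def rope_yarn(theta_extrap, i, freq_scale, ext_factor, attn_factor, corr_dim0, corr_dim1):
        # YaRN: interpolate between scaled and extrapolated theta, with magnitude correction
        mscale = attn_factor
        theta = freq_scale * theta_extrap
        if ext_factor != 0.0:
            y = (i // 2 - corr_dim0) / max(0.001, corr_dim1 - corr_dim0)
            ramp_mix = (1.0 - min(1.0, max(0.0, y))) * ext_factor
            theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix
            mscale *= 1.0 + 0.1 * math.log(1.0 / freq_scale)
        return math.cos(theta) * mscale, math.sin(theta) * mscale

    def rotate_pair(x0, x1, cos_t, sin_t):
        # normal rope rotates (i0, i0+1); neox/mrope rotate (i0/2, i0/2 + n_dims/2)
        return x0 * cos_t - x1 * sin_t, x0 * sin_t + x1 * cos_t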
@@ -0,0 +1,90 @@
#define(VARIANTS)

[
  {
    "SHADER_NAME": "scale_f32",
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "scale_f32_inplace",
    "DECLS": ["INPLACE"]
  }
]

#end(VARIANTS)

#define(DECLS)

#decl(NOT_INPLACE)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

@group(0) @binding(2)
var<uniform> params: Params;

fn store_scale(val: f32, offset: u32) {
    dst[offset] = val;
}
#enddecl(NOT_INPLACE)

#decl(INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;

fn store_scale(val: f32, offset: u32) {
    src[offset] = val;
}
#enddecl(INPLACE)

#end(DECLS)

#define(SHADER)

struct Params {
    offset_src: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    scale: f32,
    bias: f32
};

@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

DECLS

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }

    var i = gid.x;
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;

    store_scale(src[i_src] * params.scale + params.bias, i_dst);
}

#end(SHADER)
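The scale shader itself is a plain elementwise y = x * scale + bias over the strided 4D range; the only difference between the two variants is whether store_scale() writes to dst or back into src. As a one-line reference (illustrative):

    def scale_op(x, scale, bias):
        return x * scale + bias   # store_scale(src[i_src] * params.scale + params.bias, i_dst)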
@@ -2733,23 +2733,30 @@ struct test_scale : public test_case {
    const std::array<int64_t, 4> ne;
    float scale;
    float bias;
    bool inplace;

    std::string vars() override {
        return VARS_TO_STR4(type, ne, scale, bias);
        return VARS_TO_STR5(type, ne, scale, bias, inplace);
    }

    test_scale(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10},
            float scale = 2.0f,
            float bias = 0.0f)
            float bias = 0.0f,
        : type(type), ne(ne), scale(scale), bias(bias) {}
            bool inplace = false)
        : type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_set_param(a);
        ggml_set_name(a, "a");

        ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
        ggml_tensor * out;
        if (inplace) {
            out = ggml_scale_bias_inplace(ctx, a, scale, bias);
        } else {
            out = ggml_scale_bias(ctx, a, scale, bias);
        }
        ggml_set_name(out, "out");

        return out;
@@ -2906,16 +2913,18 @@ struct test_rms_norm : public test_case {
    const std::array<int64_t, 4> ne;
    const bool v; // whether a is a non-contiguous view
    const float eps;
    const bool inplace; // whether to do the operation inplace

    std::string vars() override {
        return VARS_TO_STR4(type, ne, v, eps);
        return VARS_TO_STR5(type, ne, v, eps, inplace);
    }

    test_rms_norm(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {64, 5, 4, 3},
            bool v = false,
            float eps = 1e-6f)
            float eps = 1e-6f,
        : type(type), ne(ne), v(v), eps(eps) {}
            bool inplace = false)
        : type(type), ne(ne), v(v), eps(eps), inplace(inplace) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());

@@ -2927,7 +2936,12 @@ struct test_rms_norm : public test_case {
            ggml_set_name(a, "view of a");
        }

        ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
        ggml_tensor * out;
        if (inplace) {
            out = ggml_rms_norm_inplace(ctx, a, eps);
        } else {
            out = ggml_rms_norm(ctx, a, eps);
        }
        ggml_set_name(out, "out");

        return out;
@@ -3832,17 +3846,18 @@ struct test_rope : public test_case {
    bool ff;
    int v; // view (1 : non-contiguous a)
    bool forward;
    bool inplace;

    std::string vars() override {
        // forward can be inferred from the op, does not need to be printed
        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
        return VARS_TO_STR11(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v, inplace);
    }

    test_rope(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
            int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,
            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)
        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a;

@@ -3887,7 +3902,11 @@ struct test_rope : public test_case {
                GGML_ASSERT(n_dims/4 > 0);
                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
                if (forward) {
                    out = ggml_rope_multi (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                    if (inplace) {
                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                    } else {
                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                    }
                } else {
                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                }

@@ -3895,14 +3914,22 @@ struct test_rope : public test_case {
                GGML_ASSERT(n_dims/3 > 0);
                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
                if (forward) {
                    out = ggml_rope_multi (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                    if (inplace) {
                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                    } else {
                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                    }
                } else {
                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                }
            }
        } else {
            if (forward) {
                out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                if (inplace) {
                    out = ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                } else {
                    out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                }
            } else {
                out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
            }
@@ -6183,9 +6210,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
    }

    // single in-place tests, especially important for WebGPU backend since kernels for in-place vs. not are different
    // single inplace tests, especially important for WebGPU backend since kernels for inplace vs. not are different
    test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
    test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
    test_cases.emplace_back(new test_bin_bcast(ggml_sub_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));
    test_cases.emplace_back(new test_bin_bcast(ggml_div_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16));

    // fusion
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
@@ -6200,6 +6229,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_add1());
    test_cases.emplace_back(new test_scale());
    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test
    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {100, 10, 10, 10}, 2.0f, 1.0f));
    test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f));
    test_cases.emplace_back(new test_silu_back());
@@ -6212,6 +6242,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
        test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
    }

    // in-place tests
    test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true));

    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
@@ -6559,26 +6593,26 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
        for (bool ff : {false, true}) { // freq_factors
            for (float v : { 0, 1 }) {
                test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
                test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B

                if (all) {
                    test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
                    test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
                    test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
                    test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
                    test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
                    test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
                }

                if (all) {
                    test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
                    test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
                    test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
                    test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
                    test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
                    test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)

                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw));
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw));
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));

                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                    test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                }

                if (all) {
@@ -6589,7 +6623,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                    test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
                }

                test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
                test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
            }
        }
@@ -6600,6 +6634,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        }
    }

    // single inplace test per type/mode/ff
    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
        for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) {
            for (bool ff : {false, true}) {
                test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
            }
        }
    }

    for (int v : { 0, 1, 2, 3 }) {
        for (int dim : { 0, 1, 2, 3, }) {
            test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));