Refactor use of wg size entry
This commit is contained in:
parent
b7635c409e
commit
4293531787
|
|
@ -1036,7 +1036,7 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
|
|||
}
|
||||
|
||||
// The max workgroup size is a common constant
|
||||
static std::vector<wgpu::ConstantEntry> max_wg_size_entry(webgpu_context & webgpu_ctx) {
|
||||
static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants(1);
|
||||
constants[0].key = "wg_size";
|
||||
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
|
||||
|
|
@ -1107,63 +1107,64 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
|
|||
|
||||
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
|
||||
max_wg_size_entry(webgpu_ctx));
|
||||
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32,
|
||||
"get_rows_f32", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_f32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16,
|
||||
"get_rows_f16", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_f16", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32,
|
||||
"get_rows_i32", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_i32", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0,
|
||||
"get_rows_q4_0", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q4_0", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1,
|
||||
"get_rows_q4_1", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q4_1", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0,
|
||||
"get_rows_q5_0", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q5_0", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1,
|
||||
"get_rows_q5_1", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q5_1", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0,
|
||||
"get_rows_q8_0", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q8_0", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k,
|
||||
"get_rows_q2_k", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q2_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k,
|
||||
"get_rows_q3_k", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q3_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k,
|
||||
"get_rows_q4_k", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q4_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k,
|
||||
"get_rows_q5_k", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q5_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k,
|
||||
"get_rows_q6_k", max_wg_size_entry(webgpu_ctx));
|
||||
"get_rows_q6_k", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS],
|
||||
wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", max_wg_size_entry(webgpu_ctx));
|
||||
wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS],
|
||||
wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", max_wg_size_entry(webgpu_ctx));
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S],
|
||||
wgsl_get_rows_iq2_s, "get_rows_iq2_s", max_wg_size_entry(webgpu_ctx));
|
||||
wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s,
|
||||
"get_rows_iq2_s", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS],
|
||||
wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", max_wg_size_entry(webgpu_ctx));
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S],
|
||||
wgsl_get_rows_iq3_s, "get_rows_iq3_s", max_wg_size_entry(webgpu_ctx));
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S],
|
||||
wgsl_get_rows_iq1_s, "get_rows_iq1_s", max_wg_size_entry(webgpu_ctx));
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M],
|
||||
wgsl_get_rows_iq1_m, "get_rows_iq1_m", max_wg_size_entry(webgpu_ctx));
|
||||
wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s,
|
||||
"get_rows_iq3_s", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s,
|
||||
"get_rows_iq1_s", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m,
|
||||
"get_rows_iq1_m", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL],
|
||||
wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", max_wg_size_entry(webgpu_ctx));
|
||||
wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS],
|
||||
wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", max_wg_size_entry(webgpu_ctx));
|
||||
wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
|
||||
max_wg_size_entry(webgpu_ctx));
|
||||
ggml_webgpu_max_wg_size_entry(webgpu_ctx));
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = max_wg_size_entry(webgpu_ctx);
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
|
||||
constants);
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
|
||||
|
|
|
|||
Loading…
Reference in New Issue