Refactor use of wg size entry

This commit is contained in:
Reese Levine 2025-09-11 16:56:38 -07:00
parent b7635c409e
commit 4293531787
1 changed files with 31 additions and 30 deletions

View File

@ -1036,7 +1036,7 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
} }
// The max workgroup size is a common constant // The max workgroup size is a common constant
static std::vector<wgpu::ConstantEntry> max_wg_size_entry(webgpu_context & webgpu_ctx) { static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1); std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size"; constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
@ -1107,63 +1107,64 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
max_wg_size_entry(webgpu_ctx)); ggml_webgpu_max_wg_size_entry(webgpu_ctx));
} }
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32,
"get_rows_f32", max_wg_size_entry(webgpu_ctx)); "get_rows_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16,
"get_rows_f16", max_wg_size_entry(webgpu_ctx)); "get_rows_f16", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32,
"get_rows_i32", max_wg_size_entry(webgpu_ctx)); "get_rows_i32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0,
"get_rows_q4_0", max_wg_size_entry(webgpu_ctx)); "get_rows_q4_0", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1,
"get_rows_q4_1", max_wg_size_entry(webgpu_ctx)); "get_rows_q4_1", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0,
"get_rows_q5_0", max_wg_size_entry(webgpu_ctx)); "get_rows_q5_0", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1,
"get_rows_q5_1", max_wg_size_entry(webgpu_ctx)); "get_rows_q5_1", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0,
"get_rows_q8_0", max_wg_size_entry(webgpu_ctx)); "get_rows_q8_0", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k,
"get_rows_q2_k", max_wg_size_entry(webgpu_ctx)); "get_rows_q2_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k,
"get_rows_q3_k", max_wg_size_entry(webgpu_ctx)); "get_rows_q3_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k,
"get_rows_q4_k", max_wg_size_entry(webgpu_ctx)); "get_rows_q4_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k,
"get_rows_q5_k", max_wg_size_entry(webgpu_ctx)); "get_rows_q5_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k, ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k,
"get_rows_q6_k", max_wg_size_entry(webgpu_ctx)); "get_rows_q6_k", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS],
wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", max_wg_size_entry(webgpu_ctx)); wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS],
wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", max_wg_size_entry(webgpu_ctx)); wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s,
wgsl_get_rows_iq2_s, "get_rows_iq2_s", max_wg_size_entry(webgpu_ctx)); "get_rows_iq2_s", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS],
wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", max_wg_size_entry(webgpu_ctx)); wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s,
wgsl_get_rows_iq3_s, "get_rows_iq3_s", max_wg_size_entry(webgpu_ctx)); "get_rows_iq3_s", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s,
wgsl_get_rows_iq1_s, "get_rows_iq1_s", max_wg_size_entry(webgpu_ctx)); "get_rows_iq1_s", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m,
wgsl_get_rows_iq1_m, "get_rows_iq1_m", max_wg_size_entry(webgpu_ctx)); "get_rows_iq1_m", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL],
wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", max_wg_size_entry(webgpu_ctx)); wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS], ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS],
wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", max_wg_size_entry(webgpu_ctx)); wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
} }
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
max_wg_size_entry(webgpu_ctx)); ggml_webgpu_max_wg_size_entry(webgpu_ctx));
} }
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = max_wg_size_entry(webgpu_ctx); std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
constants); constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",