Use map for shader replacements instead of pair of strings

This commit is contained in:
Reese Levine 2025-11-07 19:06:08 -08:00
parent c6bc12599f
commit 7c2b2ef237
1 changed files with 15 additions and 15 deletions

View File

@ -357,8 +357,8 @@ struct ggml_backend_webgpu_buffer_context {
// Process a WGSL shader string, replacing tokens of the form {{KEY}} with // Process a WGSL shader string, replacing tokens of the form {{KEY}} with
// the corresponding values provided in `repls`. // the corresponding values provided in `repls`.
static std::string ggml_webgpu_process_shader_repls(const char * src, static std::string ggml_webgpu_process_shader_repls(const char * src,
const std::vector<std::pair<std::string, std::string>> & repls) { const std::map<std::string, std::string> & repls) {
if (!src) { if (!src) {
return std::string(); return std::string();
} }
@ -1759,16 +1759,16 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
if (webgpu_ctx->supports_subgroup_matrix) { if (webgpu_ctx->supports_subgroup_matrix) {
std::vector<std::pair<std::string, std::string>> sg_matrix_repls; std::map<std::string, std::string> sg_matrix_repls;
sg_matrix_repls.emplace_back("WEBGPU_MAX_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size)); sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
sg_matrix_repls.emplace_back("WEBGPU_TILE_K", std::to_string(WEBGPU_MUL_MAT_TILE_K)); sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M)); sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N)); sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M)); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N)); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_M_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.M)); sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_N_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.N)); sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_K_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.K)); sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
std::string proc_mul_mat_subgroup_matrix_f32_f32 = std::string proc_mul_mat_subgroup_matrix_f32_f32 =
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
@ -1816,9 +1816,9 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N";
mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N;
std::vector<std::pair<std::string, std::string>> reg_repls; std::map<std::string, std::string> reg_repls;
reg_repls.emplace_back("WEBGPU_TILE_M", std::to_string(WEBGPU_MUL_MAT_TILE_M)); reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
reg_repls.emplace_back("WEBGPU_TILE_N", std::to_string(WEBGPU_MUL_MAT_TILE_N)); reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
// Process each reg-tile shader with tile replacements. // Process each reg-tile shader with tile replacements.
// Keep the processed strings in-scope so .c_str() remains valid. // Keep the processed strings in-scope so .c_str() remains valid.