Use map for shader replacements instead of pair of strings

This commit is contained in:
Reese Levine 2025-11-07 19:06:08 -08:00
parent c6bc12599f
commit 7c2b2ef237
1 changed files with 15 additions and 15 deletions

View File

@ -357,8 +357,8 @@ struct ggml_backend_webgpu_buffer_context {
// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
// the corresponding values provided in `repls`.
static std::string ggml_webgpu_process_shader_repls(const char * src,
const std::vector<std::pair<std::string, std::string>> & repls) {
static std::string ggml_webgpu_process_shader_repls(const char * src,
const std::map<std::string, std::string> & repls) {
if (!src) {
return std::string();
}
@ -1759,16 +1759,16 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
if (webgpu_ctx->supports_subgroup_matrix) {
std::vector<std::pair<std::string, std::string>> sg_matrix_repls;
sg_matrix_repls.emplace_back("WEBGPU_MAX_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size));
sg_matrix_repls.emplace_back("WEBGPU_TILE_K", std::to_string(WEBGPU_MUL_MAT_TILE_K));
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M));
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N));
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M));
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N));
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_M_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.M));
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_N_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.N));
sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_K_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.K));
std::map<std::string, std::string> sg_matrix_repls;
sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
std::string proc_mul_mat_subgroup_matrix_f32_f32 =
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
@ -1816,9 +1816,9 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N";
mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N;
std::vector<std::pair<std::string, std::string>> reg_repls;
reg_repls.emplace_back("WEBGPU_TILE_M", std::to_string(WEBGPU_MUL_MAT_TILE_M));
reg_repls.emplace_back("WEBGPU_TILE_N", std::to_string(WEBGPU_MUL_MAT_TILE_N));
std::map<std::string, std::string> reg_repls;
reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
// Process each reg-tile shader with tile replacements.
// Keep the processed strings in-scope so .c_str() remains valid.