clean up preprocessing

This commit is contained in:
Reese Levine 2026-02-11 15:46:06 -08:00
parent ae6baf4714
commit d3146697ee
4 changed files with 38 additions and 251 deletions

View File

@ -731,11 +731,11 @@ class ggml_webgpu_shader_lib {
// src1 type (vector)
switch (context.src1->type) {
case GGML_TYPE_F32:
defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
defines.push_back("SRC1_INNER_TYPE=f32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
defines.push_back("SRC1_INNER_TYPE=f16");
variant += "_f16";
break;
default:
@ -745,11 +745,11 @@ class ggml_webgpu_shader_lib {
// src0 type (matrix row)
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("MUL_ACC_FLOAT");
break;
case GGML_TYPE_F16:
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("MUL_ACC_FLOAT");
break;
default:
@ -764,22 +764,13 @@ class ggml_webgpu_shader_lib {
defines.push_back("MUL_ACC_" + type_upper);
// For fast path we always dequantize from f16 inside the shader
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
defines.push_back("SRC0_INNER_TYPE=f16");
break;
}
}
// dst type
defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
// vec/scalar controls
if (key.vectorized) {
defines.push_back("VEC");
defines.push_back("VEC_SIZE=4u");
} else {
defines.push_back("SCALAR");
defines.push_back("VEC_SIZE=1u");
}
// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
@ -824,27 +815,22 @@ class ggml_webgpu_shader_lib {
// src1 type
switch (context.src1->type) {
case GGML_TYPE_F32:
defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
defines.push_back("SRC1_INNER_TYPE=f32");
break;
case GGML_TYPE_F16:
defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
defines.push_back("SRC1_INNER_TYPE=f16");
break;
default:
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
}
// Shared memory type
defines.push_back(key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");
// src0 type
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
const char * src0_name = src0_traits->type_name;
switch (context.src0->type) {
case GGML_TYPE_F32:
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
defines.push_back("SRC0_INNER_TYPE=f32");
defines.push_back("FLOAT");
defines.push_back("MUL_ACC_FLOAT");
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
@ -852,7 +838,7 @@ class ggml_webgpu_shader_lib {
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
defines.push_back("SRC0_INNER_TYPE=f16");
defines.push_back("FLOAT");
defines.push_back("MUL_ACC_FLOAT");
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
@ -870,7 +856,7 @@ class ggml_webgpu_shader_lib {
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
// Use f16 inside the shader for quantized types
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
defines.push_back("SRC0_INNER_TYPE=f16");
variant += std::string("_") + src0_name;
break;
@ -879,8 +865,6 @@ class ggml_webgpu_shader_lib {
// VEC/SCALAR controls
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
defines.push_back(key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");
defines.push_back(key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
// Tiles
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
@ -1244,197 +1228,4 @@ class ggml_webgpu_shader_lib {
}
};
// Substitute every occurrence of the token {{key}} inside shader_code with
// value. The string is edited in place. After each substitution the search
// resumes immediately after the inserted text, so a match can never begin
// inside text that was just written.
inline void ggml_webgpu_replace_placeholder(std::string & shader_code,
                                            const std::string & key,
                                            const std::string & value) {
    const std::string token = std::string("{{").append(key).append("}}");
    for (size_t at = shader_code.find(token); at != std::string::npos;
         at = shader_code.find(token, at + value.size())) {
        shader_code.replace(at, token.size(), value);
    }
}
// Plain aggregate with three fields describing one processed shader.
// NOTE(review): field roles below are inferred from their names and from the
// commented-out mul_mat preprocessing routine in this header — confirm against
// the live producers before relying on them.
struct ggml_webgpu_processed_shader {
// Presumably the WGSL source text produced by the preprocessor.
std::string wgsl;
// Presumably the variant name string built for the shader (e.g. a
// "mul_mat..." name with type suffixes appended).
std::string variant;
// Type-erased shared pointer; the concrete pointee type is not recorded here,
// so consumers must know what to cast it back to.
std::shared_ptr<void> decisions;
};
/** Matrix Multiplication **/
//inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_shader(
// pre_wgsl::Preprocessor & preprocessor,
// const char * shader_src,
// const ggml_webgpu_mul_mat_shader_lib_context & context) {
// std::vector<std::string> defines;
// std::string variant = "mul_mat";
//
// // Determine base variant name based on kernel type
// if (context.key.is_vec) {
// variant = "mul_mat_vec";
// } else if (context.key.use_subgroup_matrix) {
// variant = "mul_mat_subgroup_matrix";
// } else if (context.key.register_tile) {
// variant = "mul_mat_reg_tile";
// }
//
// // Determine src0/src1 type strings
// const char * src0_type_str = nullptr;
//
// bool is_fast_path = context.key.is_vec || context.key.use_subgroup_matrix || context.key.register_tile;
//
// // src1 type
// switch (context.key.src1_type) {
// case GGML_TYPE_F32:
// defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
// defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
// break;
// case GGML_TYPE_F16:
// defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
// defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
// break;
// default:
// break;
// }
//
// // same for all types
// defines.push_back(context.key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");
//
// // src0 type
// const struct ggml_type_traits * src0_type_traits = ggml_get_type_traits(context.key.src0_type);
// src0_type_str = src0_type_traits->type_name;
//
// // for f32 and f16, account for vectorized src0 types
// switch (context.key.src0_type) {
// case GGML_TYPE_F32:
// src0_type_str = context.key.vectorized ? "vec4<f32>" : "f32";
// defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
//
// defines.push_back("FLOAT");
// defines.push_back("MUL_ACC_FLOAT");
// defines.push_back("INIT_SRC0_SHMEM_FLOAT");
// defines.push_back("INIT_SRC1_SHMEM_FLOAT");
//
// variant += "_f32";
// break;
//
// case GGML_TYPE_F16:
// src0_type_str = context.key.vectorized ? "vec4<f16>" : "f16";
// defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
//
// defines.push_back("FLOAT");
// defines.push_back("MUL_ACC_FLOAT");
// defines.push_back("INIT_SRC0_SHMEM_FLOAT");
// defines.push_back("INIT_SRC1_SHMEM_FLOAT");
//
// variant += "_f16";
//
// break;
//
// default:
// // convert name to upper case for other defines
// std::string type_upper = src0_type_str;
// std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
//
// // push back defines for quantized types
// defines.push_back("BYTE_HELPERS");
// defines.push_back(type_upper + "_T");
//
// defines.push_back(type_upper);
// defines.push_back("MUL_ACC_" + type_upper);
// defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
// defines.push_back("INIT_SRC1_SHMEM_FLOAT");
//
// // for q4_k and q5_k
// defines.push_back(type_upper + "_SCALE_MIN");
//
// // defines for i-quants
// defines.push_back(type_upper + "_TABLES");
// defines.push_back(type_upper + "_GRID");
//
// // add variant
// variant += "_";
// variant += src0_type_str;
//
// // add define for non-fast path quantized src0 type-- overwritten if fast
// // path
// defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
//
// break;
// }
//
// // Add VEC/SCALAR defines
// if (is_fast_path) {
// // if quantized type and using fast path, need to use f16 instead of the
// // quantized type
// if (context.key.src0_type != GGML_TYPE_F32 && context.key.src0_type != GGML_TYPE_F16) {
// src0_type_str = "f16";
// defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
// }
//
// // all fast paths need VEC vs SCALAR
// defines.push_back(context.key.vectorized ? "VEC" : "SCALAR");
// // add vec_size define too
// defines.push_back(context.key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");
//
// // reg_tile and subgroup_matrix need these extra defines
// if (!context.key.is_vec) {
// defines.push_back(context.key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
// }
// }
//
// // Append src1 type
// variant += std::string("_") + (context.key.src1_type == GGML_TYPE_F32 ? "f32" : "f16");
//
// // printf("DEBUG: After appending src1 type: variant='%s'\n",
// // variant.c_str());
//
// // Vectorized suffix
// if (context.key.vectorized) {
// variant += "_vec";
// }
//
// // Add defines for TILE_M and TILE_N
// defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
// defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
//
// // Add subgroup matrix defines if using subgroup_matrix
// if (context.key.use_subgroup_matrix) {
// defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
// defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
// defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
// defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
// defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
// defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
// defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
// defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
// defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
// }
//
// ggml_webgpu_processed_shader result;
// result.wgsl = preprocessor.preprocess(shader_src, defines);
// result.variant = variant;
//
// auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
// decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
// decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
// decisions->tile_k = context.key.is_vec ? WEBGPU_MUL_MAT_VEC_TILE_K : WEBGPU_MUL_MAT_TILE_K;
// decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
// decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
// decisions->wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
// decisions->outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
// decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M;
// decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N;
// decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
// decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
// decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
// decisions->is_vec = context.key.is_vec;
// decisions->use_subgroup_matrix = context.key.use_subgroup_matrix;
// result.decisions = decisions;
//
// return result;
//}
#endif // GGML_WEBGPU_SHADER_LIB_HPP

View File

@ -354,18 +354,9 @@ struct webgpu_context_struct {
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
pre_wgsl::Preprocessor p;
webgpu_buf_pool param_buf_pool;
webgpu_buf_pool set_rows_error_buf_pool;
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key,
webgpu_pipeline,
ggml_webgpu_mul_mat_pipeline_key_hash>
mul_mat_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
@ -414,25 +405,6 @@ struct ggml_backend_webgpu_buffer_context {
/* WebGPU object initializations */
// Expand {{KEY}} tokens in a WGSL shader source string.
//
//   src   - NUL-terminated shader text; a null pointer yields an empty string.
//   repls - KEY -> replacement text. Entries are applied one after another in
//           the map's iteration order, each over the whole working string, so
//           text inserted for one key can still be matched by a later key.
//
// Returns the expanded copy; the input buffer is not modified.
static std::string ggml_webgpu_process_shader_repls(const char * src,
                                                    const std::map<std::string, std::string> & repls) {
    std::string out;
    if (src == nullptr) {
        return out;
    }
    out = src;
    for (auto it = repls.begin(); it != repls.end(); ++it) {
        const std::string & replacement = it->second;
        const std::string   needle      = "{{" + it->first + "}}";
        size_t              cursor      = out.find(needle);
        while (cursor != std::string::npos) {
            out.replace(cursor, needle.size(), replacement);
            // Resume after the inserted text so it is not re-matched for this key.
            cursor = out.find(needle, cursor + replacement.size());
        }
    }
    return out;
}
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
const char * shader_code,
const char * label,

View File

@ -1,4 +1,10 @@
#ifdef SHMEM_VEC
#ifdef VEC
#define VEC_SIZE 4
#define SHMEM_TYPE vec4<f16>
#define DST_TYPE vec4<f32>
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
fn store_shmem(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x;
shmem[idx + 1] = val.y;
@ -7,7 +13,13 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
}
#endif
#ifdef SHMEM_SCALAR
#ifdef SCALAR
#define VEC_SIZE 1
#define SHMEM_TYPE f16
#define DST_TYPE f32
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE
// Scalar variant of store_shmem: writes the single f16 value `val` into the
// `shmem` array at index `idx`. `shmem` is declared outside this snippet.
fn store_shmem(val: f16, idx: u32) {
shmem[idx] = val;
}

View File

@ -4,6 +4,12 @@ enable f16;
#include "common_decls.tmpl"
#ifdef VEC
#define VEC_SIZE 4
#define DST_TYPE vec4<f32>
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
// VEC variant of inner_dot: converts src0_val to SRC1_TYPE so both operands
// share a type, takes their dot product, and returns the result as f32.
// SRC0_TYPE / SRC1_TYPE are preprocessor macros expanded before compilation.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
return f32(dot(SRC1_TYPE(src0_val), src1_val));
}
@ -17,6 +23,12 @@ fn store_val(group_base: u32) -> vec4<f32> {
#endif
#ifdef SCALAR
#define VEC_SIZE 1
#define DST_TYPE f32
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE
// SCALAR variant of inner_dot: converts each operand to f32 and returns their
// product. SRC0_TYPE / SRC1_TYPE are preprocessor macros expanded before
// compilation.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
return f32(src0_val) * f32(src1_val);
}