clean up preprocessing
This commit is contained in:
parent
ae6baf4714
commit
d3146697ee
|
|
@ -731,11 +731,11 @@ class ggml_webgpu_shader_lib {
|
|||
// src1 type (vector)
|
||||
switch (context.src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
|
||||
defines.push_back("SRC1_INNER_TYPE=f32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
|
||||
defines.push_back("SRC1_INNER_TYPE=f16");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
|
|
@ -745,11 +745,11 @@ class ggml_webgpu_shader_lib {
|
|||
// src0 type (matrix row)
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
break;
|
||||
default:
|
||||
|
|
@ -764,22 +764,13 @@ class ggml_webgpu_shader_lib {
|
|||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
|
||||
// For fast path we always dequantize from f16 inside the shader
|
||||
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// dst type
|
||||
defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
|
||||
|
||||
// vec/scalar controls
|
||||
if (key.vectorized) {
|
||||
defines.push_back("VEC");
|
||||
defines.push_back("VEC_SIZE=4u");
|
||||
} else {
|
||||
defines.push_back("SCALAR");
|
||||
defines.push_back("VEC_SIZE=1u");
|
||||
}
|
||||
// VEC/SCALAR controls
|
||||
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
||||
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
|
||||
|
|
@ -824,27 +815,22 @@ class ggml_webgpu_shader_lib {
|
|||
// src1 type
|
||||
switch (context.src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
|
||||
defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
|
||||
defines.push_back("SRC1_INNER_TYPE=f32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back(key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
|
||||
defines.push_back(key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
|
||||
defines.push_back("SRC1_INNER_TYPE=f16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
|
||||
}
|
||||
|
||||
// Shared memory type
|
||||
defines.push_back(key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");
|
||||
|
||||
// src0 type
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
const char * src0_name = src0_traits->type_name;
|
||||
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("FLOAT");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
||||
|
|
@ -852,7 +838,7 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("FLOAT");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
||||
|
|
@ -870,7 +856,7 @@ class ggml_webgpu_shader_lib {
|
|||
defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
||||
|
||||
// Use f16 inside the shader for quantized types
|
||||
defines.push_back(key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
|
||||
variant += std::string("_") + src0_name;
|
||||
break;
|
||||
|
|
@ -879,8 +865,6 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
// VEC/SCALAR controls
|
||||
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
||||
defines.push_back(key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");
|
||||
defines.push_back(key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
|
||||
|
||||
// Tiles
|
||||
defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
||||
|
|
@ -1244,197 +1228,4 @@ class ggml_webgpu_shader_lib {
|
|||
}
|
||||
};
|
||||
|
||||
// Helper function for replacing {{PLACEHOLDERS}} in shader source text.
//
// Replaces every occurrence of "{{" + key + "}}" in `shader_code` with `value`.
// The string is edited in place; nothing is returned.
//
//   shader_code  text to edit (modified in place)
//   key          placeholder name, without the surrounding double braces
//   value        text substituted for each occurrence of the placeholder
|
||||
inline void ggml_webgpu_replace_placeholder(std::string & shader_code,
|
||||
const std::string & key,
|
||||
const std::string & value) {
|
||||
// Full token to search for, e.g. key "FOO" -> "{{FOO}}".
std::string pattern = "{{" + key + "}}";
|
||||
size_t pos = 0;
|
||||
// Scan forward from `pos` until no further occurrence is found.
while ((pos = shader_code.find(pattern, pos)) != std::string::npos) {
|
||||
shader_code.replace(pos, pattern.length(), value);
|
||||
// Resume searching after the inserted text, so the replacement itself is
// never rescanned (a `value` that contains the token cannot loop forever).
pos += value.length();
|
||||
}
|
||||
}
|
||||
|
||||
// Result bundle for a preprocessed shader.
// NOTE(review): field meanings below are taken from the commented-out
// ggml_webgpu_preprocess_mul_mat_shader routine in this file — confirm against
// the live callers.
struct ggml_webgpu_processed_shader {
|
||||
// Shader source text (the commented-out routine stores the preprocessor output here).
std::string wgsl;
|
||||
// Variant name string (built there from kernel kind and src0/src1 type suffixes).
std::string variant;
|
||||
// Type-erased pointer to a per-shader "decisions" object; the concrete type is
// not recorded here, so the reader must know what was stored.
std::shared_ptr<void> decisions;
|
||||
};
|
||||
|
||||
/** Matrix Multiplication **/
|
||||
|
||||
//inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_mul_mat_shader(
|
||||
// pre_wgsl::Preprocessor & preprocessor,
|
||||
// const char * shader_src,
|
||||
// const ggml_webgpu_mul_mat_shader_lib_context & context) {
|
||||
// std::vector<std::string> defines;
|
||||
// std::string variant = "mul_mat";
|
||||
//
|
||||
// // Determine base variant name based on kernel type
|
||||
// if (context.key.is_vec) {
|
||||
// variant = "mul_mat_vec";
|
||||
// } else if (context.key.use_subgroup_matrix) {
|
||||
// variant = "mul_mat_subgroup_matrix";
|
||||
// } else if (context.key.register_tile) {
|
||||
// variant = "mul_mat_reg_tile";
|
||||
// }
|
||||
//
|
||||
// // Determine src0/src1 type strings
|
||||
// const char * src0_type_str = nullptr;
|
||||
//
|
||||
// bool is_fast_path = context.key.is_vec || context.key.use_subgroup_matrix || context.key.register_tile;
|
||||
//
|
||||
// // src1 type
|
||||
// switch (context.key.src1_type) {
|
||||
// case GGML_TYPE_F32:
|
||||
// defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f32>" : "SRC1_TYPE=f32");
|
||||
// defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
|
||||
// break;
|
||||
// case GGML_TYPE_F16:
|
||||
// defines.push_back(context.key.vectorized ? "SRC1_TYPE=vec4<f16>" : "SRC1_TYPE=f16");
|
||||
// defines.push_back(context.key.vectorized ? "DST_TYPE=vec4<f32>" : "DST_TYPE=f32");
|
||||
// break;
|
||||
// default:
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// // same for all types
|
||||
// defines.push_back(context.key.vectorized ? "SHMEM_TYPE=vec4<f16>" : "SHMEM_TYPE=f16");
|
||||
//
|
||||
// // src0 type
|
||||
// const struct ggml_type_traits * src0_type_traits = ggml_get_type_traits(context.key.src0_type);
|
||||
// src0_type_str = src0_type_traits->type_name;
|
||||
//
|
||||
// // for f32 and f16, account for vectorized src0 types
|
||||
// switch (context.key.src0_type) {
|
||||
// case GGML_TYPE_F32:
|
||||
// src0_type_str = context.key.vectorized ? "vec4<f32>" : "f32";
|
||||
// defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f32>" : "SRC0_TYPE=f32");
|
||||
//
|
||||
// defines.push_back("FLOAT");
|
||||
// defines.push_back("MUL_ACC_FLOAT");
|
||||
// defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
||||
// defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
||||
//
|
||||
// variant += "_f32";
|
||||
// break;
|
||||
//
|
||||
// case GGML_TYPE_F16:
|
||||
// src0_type_str = context.key.vectorized ? "vec4<f16>" : "f16";
|
||||
// defines.push_back(context.key.vectorized ? "SRC0_TYPE=vec4<f16>" : "SRC0_TYPE=f16");
|
||||
//
|
||||
// defines.push_back("FLOAT");
|
||||
// defines.push_back("MUL_ACC_FLOAT");
|
||||
// defines.push_back("INIT_SRC0_SHMEM_FLOAT");
|
||||
// defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
||||
//
|
||||
// variant += "_f16";
|
||||
//
|
||||
// break;
|
||||
//
|
||||
// default:
|
||||
// // convert name to upper case for other defines
|
||||
// std::string type_upper = src0_type_str;
|
||||
// std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
//
|
||||
// // push back defines for quantized types
|
||||
// defines.push_back("BYTE_HELPERS");
|
||||
// defines.push_back(type_upper + "_T");
|
||||
//
|
||||
// defines.push_back(type_upper);
|
||||
// defines.push_back("MUL_ACC_" + type_upper);
|
||||
// defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
|
||||
// defines.push_back("INIT_SRC1_SHMEM_FLOAT");
|
||||
//
|
||||
// // for q4_k and q5_k
|
||||
// defines.push_back(type_upper + "_SCALE_MIN");
|
||||
//
|
||||
// // defines for i-quants
|
||||
// defines.push_back(type_upper + "_TABLES");
|
||||
// defines.push_back(type_upper + "_GRID");
|
||||
//
|
||||
// // add variant
|
||||
// variant += "_";
|
||||
// variant += src0_type_str;
|
||||
//
|
||||
// // add define for non-fast path quantized src0 type-- overwritten if fast
|
||||
// // path
|
||||
// defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
|
||||
//
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// // Add VEC/SCALAR defines
|
||||
// if (is_fast_path) {
|
||||
// // if quantized type and using fast path, need to use f16 instead of the
|
||||
// // quantized type
|
||||
// if (context.key.src0_type != GGML_TYPE_F32 && context.key.src0_type != GGML_TYPE_F16) {
|
||||
// src0_type_str = "f16";
|
||||
// defines.push_back(std::string("SRC0_TYPE=") + src0_type_str);
|
||||
// }
|
||||
//
|
||||
// // all fast paths need VEC vs SCALAR
|
||||
// defines.push_back(context.key.vectorized ? "VEC" : "SCALAR");
|
||||
// // add vec_size define too
|
||||
// defines.push_back(context.key.vectorized ? "VEC_SIZE=4u" : "VEC_SIZE=1u");
|
||||
//
|
||||
// // reg_tile and subgroup_matrix need these extra defines
|
||||
// if (!context.key.is_vec) {
|
||||
// defines.push_back(context.key.vectorized ? "SHMEM_VEC" : "SHMEM_SCALAR");
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Append src1 type
|
||||
// variant += std::string("_") + (context.key.src1_type == GGML_TYPE_F32 ? "f32" : "f16");
|
||||
//
|
||||
// // printf("DEBUG: After appending src1 type: variant='%s'\n",
|
||||
// // variant.c_str());
|
||||
//
|
||||
// // Vectorized suffix
|
||||
// if (context.key.vectorized) {
|
||||
// variant += "_vec";
|
||||
// }
|
||||
//
|
||||
// // Add defines for TILE_M and TILE_N
|
||||
// defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
|
||||
// defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
|
||||
//
|
||||
// // Add subgroup matrix defines if using subgroup_matrix
|
||||
// if (context.key.use_subgroup_matrix) {
|
||||
// defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
||||
// defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
|
||||
// defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
|
||||
// defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
|
||||
// defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
|
||||
// defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
|
||||
// defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
|
||||
// defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
|
||||
// defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
|
||||
// }
|
||||
//
|
||||
// ggml_webgpu_processed_shader result;
|
||||
// result.wgsl = preprocessor.preprocess(shader_src, defines);
|
||||
// result.variant = variant;
|
||||
//
|
||||
// auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
|
||||
// decisions->tile_m = WEBGPU_MUL_MAT_TILE_M;
|
||||
// decisions->tile_n = WEBGPU_MUL_MAT_TILE_N;
|
||||
// decisions->tile_k = context.key.is_vec ? WEBGPU_MUL_MAT_VEC_TILE_K : WEBGPU_MUL_MAT_TILE_K;
|
||||
// decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M;
|
||||
// decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N;
|
||||
// decisions->wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
// decisions->outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
|
||||
// decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M;
|
||||
// decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N;
|
||||
// decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
|
||||
// decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
|
||||
// decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
|
||||
// decisions->is_vec = context.key.is_vec;
|
||||
// decisions->use_subgroup_matrix = context.key.use_subgroup_matrix;
|
||||
// result.decisions = decisions;
|
||||
//
|
||||
// return result;
|
||||
//}
|
||||
|
||||
#endif // GGML_WEBGPU_SHADER_LIB_HPP
|
||||
|
|
|
|||
|
|
@ -354,18 +354,9 @@ struct webgpu_context_struct {
|
|||
|
||||
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
|
||||
|
||||
pre_wgsl::Preprocessor p;
|
||||
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
|
||||
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_mul_mat_pipeline_key_hash>
|
||||
mul_mat_pipelines; // src0_type, src1_type, vectorized
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
|
||||
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
|
||||
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
|
||||
|
|
@ -414,25 +405,6 @@ struct ggml_backend_webgpu_buffer_context {
|
|||
|
||||
/* WebGPU object initializations */
|
||||
|
||||
// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
|
||||
// the corresponding values provided in `repls`.
//
//   src    NUL-terminated shader text; may be null
//   repls  map from KEY (without braces) to its replacement text
//
// Returns a new string with all replacements applied, or an empty string when
// `src` is null. The input buffer is not modified.
//
// Replacements are applied one key at a time, in the map's iteration order
// (ascending key order for std::map), each on the output of the previous one.
// A value that contains a later key's {{TOKEN}} will therefore have that
// token substituted as well.
|
||||
static std::string ggml_webgpu_process_shader_repls(const char * src,
|
||||
const std::map<std::string, std::string> & repls) {
|
||||
// Null source: nothing to process.
if (!src) {
|
||||
return std::string();
|
||||
}
|
||||
// Work on a private copy of the source text.
std::string s = src;
|
||||
for (const auto & kv : repls) {
|
||||
// Full token to search for, e.g. key "FOO" -> "{{FOO}}".
std::string token = "{{" + kv.first + "}}";
|
||||
size_t pos = 0;
|
||||
while ((pos = s.find(token, pos)) != std::string::npos) {
|
||||
s.replace(pos, token.length(), kv.second);
|
||||
// Skip past the inserted text so it is not rescanned for this same token.
pos += kv.second.length();
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
||||
const char * shader_code,
|
||||
const char * label,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
#ifdef SHMEM_VEC
|
||||
#ifdef VEC
|
||||
#define VEC_SIZE 4
|
||||
#define SHMEM_TYPE vec4<f16>
|
||||
#define DST_TYPE vec4<f32>
|
||||
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
|
||||
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
|
||||
|
||||
fn store_shmem(val: vec4<f16>, idx: u32) {
|
||||
shmem[idx] = val.x;
|
||||
shmem[idx + 1] = val.y;
|
||||
|
|
@ -7,7 +13,13 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef SHMEM_SCALAR
|
||||
#ifdef SCALAR
|
||||
#define VEC_SIZE 1
|
||||
#define SHMEM_TYPE f16
|
||||
#define DST_TYPE f32
|
||||
#define SRC0_TYPE SRC0_INNER_TYPE
|
||||
#define SRC1_TYPE SRC1_INNER_TYPE
|
||||
|
||||
// Scalar variant of store_shmem: writes the single f16 `val` into `shmem` at
// index `idx`.
fn store_shmem(val: f16, idx: u32) {
|
||||
shmem[idx] = val;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,12 @@ enable f16;
|
|||
#include "common_decls.tmpl"
|
||||
|
||||
#ifdef VEC
|
||||
|
||||
#define VEC_SIZE 4
|
||||
#define DST_TYPE vec4<f32>
|
||||
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
|
||||
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
|
||||
|
||||
// VEC variant of inner_dot (SRC0_TYPE / SRC1_TYPE are vec4 types in this branch).
// Returns the dot product of the two operands as f32.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
|
||||
// src0_val is first converted to SRC1_TYPE so both dot() operands share a type;
// the dot result is then converted to f32.
return f32(dot(SRC1_TYPE(src0_val), src1_val));
|
||||
}
|
||||
|
|
@ -17,6 +23,12 @@ fn store_val(group_base: u32) -> vec4<f32> {
|
|||
#endif
|
||||
|
||||
#ifdef SCALAR
|
||||
|
||||
#define VEC_SIZE 1
|
||||
#define DST_TYPE f32
|
||||
#define SRC0_TYPE SRC0_INNER_TYPE
|
||||
#define SRC1_TYPE SRC1_INNER_TYPE
|
||||
|
||||
// SCALAR variant of inner_dot (SRC0_TYPE / SRC1_TYPE are the scalar inner types
// in this branch). Returns the product of the two operands as f32.
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
|
||||
// Each operand is converted to f32 before multiplying.
return f32(src0_val) * f32(src1_val);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue