Refactor all but flashattention, mat mul

This commit is contained in:
Reese Levine 2026-02-11 10:19:05 -08:00
parent 8a13bbb11b
commit a4e9b45306
3 changed files with 494 additions and 607 deletions

View File

@@ -62,6 +62,9 @@ struct ggml_webgpu_shader_lib_context {
ggml_tensor * dst;
uint32_t max_wg_size;
size_t wg_mem_limit_bytes = 0;
bool inplace = false;
bool overlap = false;
};
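This single context now serves every refactored op. For orientation, a call site mirroring the ones in the second file of this diff would fill it like so (a sketch; `ctx` and the capability lookup follow the patterns shown later in this commit):

ggml_webgpu_shader_lib_context shader_lib_ctx = {
    .src0        = src,
    .src1        = nullptr,
    .dst         = dst,
    .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    .inplace     = ggml_webgpu_tensor_equal(src, dst),
};
webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);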
struct webgpu_pipeline {
@@ -70,6 +73,18 @@ struct webgpu_pipeline {
std::shared_ptr<void> context = nullptr;
};
struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size = 0;
};
/** Argsort **/
struct ggml_webgpu_argsort_shader_lib_context {
uint32_t max_wg_size;
size_t wg_mem_limit_bytes;
int32_t order;
};
/** Set Rows **/
struct ggml_webgpu_set_rows_pipeline_key {
@@ -98,12 +113,124 @@ struct ggml_webgpu_set_rows_shader_decisions {
uint32_t wg_size;
};
/** Get Rows **/
struct ggml_webgpu_get_rows_pipeline_key {
ggml_type src_type;
int vectorized;
bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
return src_type == other.src_type && vectorized == other.vectorized;
}
};
struct ggml_webgpu_get_rows_pipeline_key_hash {
size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
};
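Each key-hash functor in this file folds its fields into a seed via ggml_webgpu_hash_combine, which is defined earlier in the file, outside the hunks shown here. A minimal boost-style sketch of such a combiner, assuming the usual golden-ratio mixing (the real body may differ):

#include <cstddef>    // std::size_t
#include <functional> // std::hash

// Fold std::hash of `value` into `seed` (boost::hash_combine-style sketch).
template <typename T>
static inline void ggml_webgpu_hash_combine(std::size_t & seed, const T & value) {
    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}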
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
};
struct ggml_webgpu_pad_pipeline_key_hash {
size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.circular);
return seed;
}
};
/** Scale **/
struct ggml_webgpu_scale_pipeline_key {
int inplace;
bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
};
struct ggml_webgpu_scale_pipeline_key_hash {
size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Binary **/
struct ggml_webgpu_binary_pipeline_key {
int type;
int op;
bool inplace;
bool overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
}
};
struct ggml_webgpu_binary_pipeline_key_hash {
size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.overlap);
return seed;
}
};
/** Unary **/
struct ggml_webgpu_unary_pipeline_key {
int type;
int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace;
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
}
};
struct ggml_webgpu_unary_pipeline_key_hash {
size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
unary_pipelines; // type/op/inplace
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
pad_pipelines; // circular/non-circular
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
@@ -191,6 +318,317 @@ class ggml_webgpu_shader_lib {
return set_rows_pipelines[key];
}
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
auto it = cumsum_pipelines.find(1);
if (it != cumsum_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_cumsum, defines);
cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum");
return cumsum_pipelines[1];
}
webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {
bool is_top_k = context.dst->op == GGML_OP_TOP_K;
// ascending order is 0, descending order is 1
const int32_t order =
is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
auto it = argsort_pipelines.find(order);
if (it != argsort_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "argsort";
defines.push_back(std::string("ORDER=") + std::to_string(order));
variant += std::string("_order") + std::to_string(order);
uint32_t wg_size = 1;
while (wg_size * 2 <= context.max_wg_size &&
wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
wg_size *= 2;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
auto processed = preprocessor.preprocess(wgsl_argsort, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = wg_size;
argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
argsort_pipelines[order].context = decisions;
return argsort_pipelines[order];
}
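The sizing loop in get_argsort_pipeline doubles wg_size while it stays within both device limits. A self-contained sketch for experimenting with the rule, assuming GGML_WEBGPU_I32_SIZE_BYTES is 4 (the WebGPU defaults are 256 invocations per workgroup and 16384 bytes of workgroup storage):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Standalone reproduction of the sizing rule above. Note the loop tests the
// pre-doubling size against *half* the budget, so the final wg_size may use
// up to the full workgroup-storage limit.
static uint32_t argsort_wg_size(uint32_t max_wg_size, std::size_t wg_mem_limit_bytes) {
    const std::size_t i32_bytes = 4; // assumed value of GGML_WEBGPU_I32_SIZE_BYTES
    uint32_t wg_size = 1;
    while (wg_size * 2 <= max_wg_size && wg_size * i32_bytes <= wg_mem_limit_bytes / 2) {
        wg_size *= 2;
    }
    return wg_size;
}

int main() {
    printf("%u\n", argsort_wg_size(256, 16384)); // 256: invocation limit binds
    printf("%u\n", argsort_wg_size(1024, 2048)); // 512: memory binds (512 * 4 = 2048 bytes)
    return 0;
}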
webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {
bool is_top_k = context.dst->op == GGML_OP_TOP_K;
// ascending order is 0, descending order is 1
const int32_t order =
is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
auto it = argsort_merge_pipelines.find(order);
if (it != argsort_merge_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "argsort_merge";
defines.push_back(std::string("ORDER=") + std::to_string(order));
variant += std::string("_order") + std::to_string(order);
uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines);
argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
return argsort_merge_pipelines[order];
}
webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
ggml_webgpu_get_rows_pipeline_key key = {
.src_type = context.src0->type,
.vectorized = (int) vectorized,
};
auto it = get_rows_pipelines.find(key);
if (it != get_rows_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "get_rows";
const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type);
const char * type_str = type_traits->type_name;
switch (key.src_type) {
case GGML_TYPE_F32:
if (key.vectorized) {
defines.push_back("F32_VEC");
defines.push_back("SRC_TYPE=vec4<f32>");
defines.push_back("DST_TYPE=vec4<f32>");
defines.push_back("BLOCK_SIZE=4u");
} else {
defines.push_back("F32");
defines.push_back("SRC_TYPE=f32");
defines.push_back("DST_TYPE=f32");
defines.push_back("BLOCK_SIZE=1u");
}
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("F16");
defines.push_back("SRC_TYPE=f16");
defines.push_back("DST_TYPE=f32");
defines.push_back("BLOCK_SIZE=1u");
variant += "_f16";
break;
case GGML_TYPE_I32:
defines.push_back("I32");
defines.push_back("SRC_TYPE=i32");
defines.push_back("DST_TYPE=i32");
defines.push_back("BLOCK_SIZE=1u");
variant += "_i32";
break;
default: {
std::string type_upper = type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
defines.push_back("BYTE_HELPERS");
defines.push_back(type_upper + "_T");
defines.push_back(type_upper);
defines.push_back(type_upper + "_SCALE_MIN");
defines.push_back(type_upper + "_TABLES");
defines.push_back(type_upper + "_GRID");
variant += "_";
variant += type_str;
defines.push_back(std::string("SRC_TYPE=") + type_str);
defines.push_back("DST_TYPE=f32");
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
key.src_type == GGML_TYPE_IQ4_NL) {
defines.push_back("BLOCK_SIZE=32u");
} else if (key.src_type >= GGML_TYPE_Q2_K) {
defines.push_back("BLOCK_SIZE=256u");
} else {
defines.push_back("BLOCK_SIZE=1u");
}
break;
}
}
if (key.vectorized) {
variant += "_vec";
}
defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_get_rows, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
get_rows_pipelines[key] = pipeline;
return get_rows_pipelines[key];
}
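For quantized sources the default branch assembles the defines from the ggml type name. Tracing it by hand for GGML_TYPE_Q4_K (whose ggml name is "q4_K") gives approximately the following; the list is reconstructed from the switch above, with WG_SIZE depending on the device:

// Defines for a hypothetical get_rows pipeline with src_type = GGML_TYPE_Q4_K:
//   BYTE_HELPERS, Q4_K_T, Q4_K, Q4_K_SCALE_MIN, Q4_K_TABLES, Q4_K_GRID,
//   SRC_TYPE=q4_K, DST_TYPE=f32, BLOCK_SIZE=256u (K-quant), WG_SIZE=<max_wg_size>
// cached under key {GGML_TYPE_Q4_K, 0} with variant name "get_rows_q4_K".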
webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
auto it = scale_pipelines.find(key);
if (it != scale_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "scale";
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_scale, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
scale_pipelines[key] = pipeline;
return scale_pipelines[key];
}
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
auto it = pad_pipelines.find(key);
if (it != pad_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "pad";
if (key.circular) {
defines.push_back("CIRCULAR");
variant += "_circular";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_pad, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
pad_pipelines[key] = pipeline;
return pad_pipelines[key];
}
webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool is_unary = context.dst->op == GGML_OP_UNARY;
const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
ggml_webgpu_unary_pipeline_key key = {
.type = context.dst->type,
.op = op,
.is_unary = is_unary,
.inplace = context.inplace,
};
auto it = unary_pipelines.find(key);
if (it != unary_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) :
ggml_op_name((ggml_op) key.op);
defines.push_back(variant);
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for unary shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_unary, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
unary_pipelines[key] = pipeline;
return unary_pipelines[key];
}
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_binary_pipeline_key key = {
.type = context.dst->type,
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
};
auto it = binary_pipelines.find(key);
if (it != binary_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string op_name = ggml_op_name((ggml_op) key.op);
std::string variant = op_name;
defines.push_back(std::string("OP_") + op_name);
switch (key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for binary shader");
}
if (key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
} else if (key.overlap) {
defines.push_back("OVERLAP");
variant += "_overlap";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_binary, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
binary_pipelines[key] = pipeline;
return binary_pipelines[key];
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,
@@ -404,443 +842,6 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
result.decisions = decisions;
return result;
}
/** Generic **/
struct ggml_webgpu_generic_shader_lib_context {
int vec4;
uint32_t max_wg_size;
};
struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_generic_shader_lib_context & context,
const std::string & base_variant) {
std::vector<std::string> defines;
std::string variant = base_variant;
if (context.vec4) {
defines.push_back("VEC4");
variant += "_vec";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}
/** Pad **/
struct ggml_webgpu_pad_pipeline_key {
bool circular;
bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
};
struct ggml_webgpu_pad_pipeline_key_hash {
size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.circular);
return seed;
}
};
struct ggml_webgpu_pad_shader_lib_context {
ggml_webgpu_pad_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_pad_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "pad";
if (context.key.circular) {
defines.push_back("CIRCULAR");
variant += "_circular";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
/** Argsort **/
struct ggml_webgpu_argsort_shader_lib_context {
uint32_t max_wg_size;
size_t wg_mem_limit_bytes;
int32_t order;
};
struct ggml_webgpu_argsort_shader_decisions {
uint32_t wg_size = 0;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_argsort_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "argsort";
defines.push_back(std::string("ORDER=") + std::to_string(context.order));
variant += std::string("_order") + std::to_string(context.order);
uint32_t wg_size = 1;
while (wg_size * 2 <= context.max_wg_size &&
wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
wg_size *= 2;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_argsort_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "argsort_merge";
defines.push_back(std::string("ORDER=") + std::to_string(context.order));
variant += std::string("_order") + std::to_string(context.order);
uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
struct ggml_webgpu_unary_pipeline_key {
int type;
int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace;
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
}
};
struct ggml_webgpu_unary_pipeline_key_hash {
size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
struct ggml_webgpu_unary_shader_lib_context {
ggml_webgpu_unary_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_unary_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
ggml_op_name((ggml_op) context.key.op);
// Operation-specific behavior
defines.push_back(variant);
switch (context.key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for unary shader");
}
if (context.key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
/** Scale **/
struct ggml_webgpu_scale_pipeline_key {
int inplace;
bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
};
struct ggml_webgpu_scale_pipeline_key_hash {
size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.inplace);
return seed;
}
};
/** Binary **/
struct ggml_webgpu_binary_pipeline_key {
int type;
int op;
bool inplace;
bool overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
}
};
struct ggml_webgpu_binary_pipeline_key_hash {
size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.overlap);
return seed;
}
};
struct ggml_webgpu_scale_shader_lib_context {
ggml_webgpu_scale_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_scale_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_scale_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "scale";
if (context.key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
struct ggml_webgpu_binary_shader_lib_context {
ggml_webgpu_binary_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_binary_shader_lib_context & context) {
std::vector<std::string> defines;
std::string op_name = ggml_op_name((ggml_op) context.key.op);
std::string variant = op_name;
defines.push_back(std::string("OP_") + op_name);
switch (context.key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for binary shader");
}
if (context.key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
} else if (context.key.overlap) {
defines.push_back("OVERLAP");
variant += "_overlap";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
/** get_rows */
struct ggml_webgpu_get_rows_pipeline_key {
ggml_type src_type;
int vectorized;
bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
return src_type == other.src_type && vectorized == other.vectorized;
}
};
struct ggml_webgpu_get_rows_pipeline_key_hash {
size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.src_type);
ggml_webgpu_hash_combine(seed, key.vectorized);
return seed;
}
};
struct ggml_webgpu_get_rows_shader_lib_context {
ggml_webgpu_get_rows_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_get_rows_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_get_rows_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "get_rows";
// Determine src type string
const char * type_str = nullptr;
// src type
const struct ggml_type_traits * type_traits = ggml_get_type_traits(context.key.src_type);
type_str = type_traits->type_name;
switch (context.key.src_type) {
case GGML_TYPE_F32:
if (context.key.vectorized) {
defines.push_back("F32_VEC");
defines.push_back("SRC_TYPE=vec4<f32>");
defines.push_back("DST_TYPE=vec4<f32>");
defines.push_back("BLOCK_SIZE=4u");
} else {
defines.push_back("F32");
defines.push_back("SRC_TYPE=f32");
defines.push_back("DST_TYPE=f32");
defines.push_back("BLOCK_SIZE=1u");
}
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("F16");
defines.push_back("SRC_TYPE=f16");
defines.push_back("DST_TYPE=f32");
defines.push_back("BLOCK_SIZE=1u");
variant += "_f16";
break;
case GGML_TYPE_I32:
defines.push_back("I32");
defines.push_back("SRC_TYPE=i32");
defines.push_back("DST_TYPE=i32");
defines.push_back("BLOCK_SIZE=1u");
variant += "_i32";
break;
default:
// convert name to upper case for other defines
std::string type_upper = type_str;
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
// push back defines for quantized types
defines.push_back("BYTE_HELPERS");
defines.push_back(type_upper + "_T");
defines.push_back(type_upper);
// for q4_k and q5_k
defines.push_back(type_upper + "_SCALE_MIN");
// defines for i-quants
defines.push_back(type_upper + "_TABLES");
defines.push_back(type_upper + "_GRID");
// add variant
variant += "_";
variant += type_str;
// add define for quantized src0 type
defines.push_back(std::string("SRC_TYPE=") + type_str);
defines.push_back("DST_TYPE=f32");
break;
}
// determine block_size for quantized types
if (context.key.src_type == GGML_TYPE_I32) {
defines.push_back("BLOCK_SIZE=1u");
} else if ((context.key.src_type >= GGML_TYPE_Q4_0 && context.key.src_type <= GGML_TYPE_Q8_1) ||
context.key.src_type == GGML_TYPE_IQ4_NL) {
// Non-K quants use 32
defines.push_back("BLOCK_SIZE=32u");
} else if (context.key.src_type >= GGML_TYPE_Q2_K) {
// K-quants and IQ variants all use 256
defines.push_back("BLOCK_SIZE=256u");
}
// Vectorized suffix
if (context.key.vectorized) {
variant += "_vec";
}
defines.push_back("WORKGROUP_SIZE=" + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
// Create decisions structure to store workgroup size
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
/** Matrix Multiplication **/
struct ggml_webgpu_mul_mat_pipeline_key {

View File

@@ -369,32 +369,13 @@ struct webgpu_context_struct {
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order (asc/desc)
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_get_rows_pipeline_key,
webgpu_pipeline,
ggml_webgpu_get_rows_pipeline_key_hash>
get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines;
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
std::unordered_map<ggml_webgpu_scale_pipeline_key,
webgpu_pipeline,
ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
unary_pipelines;
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
size_t memset_bytes_per_thread;
};
@@ -909,24 +890,13 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
}
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
const bool circular = ggml_get_op_params_i32(dst, 8) != 0;
ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular };
ggml_webgpu_pad_shader_lib_context shader_lib_ctx = {
.key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline;
auto it = ctx->pad_pipelines.find(pipeline_key);
if (it != ctx->pad_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->pad_pipelines.emplace(pipeline_key, pipeline);
}
webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -1061,37 +1031,15 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
ggml_tensor * dst) {
uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.dst = dst,
.max_wg_size = WEBGPU_MAX_WG_SIZE,
};
// Create pipeline key
ggml_webgpu_get_rows_pipeline_key key = { .src_type = src->type, .vectorized = (int) vectorized };
// Get or create pipeline
webgpu_pipeline pipeline;
auto it = ctx->get_rows_pipelines.find(key);
if (it != ctx->get_rows_pipelines.end()) {
pipeline = it->second;
} else {
// JIT compile the shader
const char * shader_src = wgsl_get_rows;
// Build shader context
ggml_webgpu_get_rows_shader_lib_context shader_lib_ctx = {
.key = key,
.max_wg_size = WEBGPU_MAX_WG_SIZE,
};
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_get_rows_shader(ctx->p, shader_src, shader_lib_ctx);
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(shader_lib_ctx.max_wg_size);
pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
processed.variant.c_str(), constants);
pipeline.context = processed.decisions;
ctx->get_rows_pipelines.emplace(key, pipeline);
}
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
@@ -1445,27 +1393,16 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_unary = dst->op == GGML_OP_UNARY;
bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
int op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op;
ggml_webgpu_unary_pipeline_key pipeline_key = {
.type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace
};
ggml_webgpu_unary_shader_lib_context shader_lib_ctx = {
.key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = inplace,
};
webgpu_pipeline pipeline;
auto it = ctx->unary_pipelines.find(pipeline_key);
if (it != ctx->unary_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->unary_pipelines.emplace(pipeline_key, pipeline);
}
webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -1537,30 +1474,18 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * dst) {
binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
ggml_webgpu_binary_pipeline_key pipeline_key = {
.type = dst->type,
.op = dst->op,
.inplace = flags.inplace,
.overlap = flags.overlap,
};
ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
.key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = flags.inplace,
.overlap = flags.overlap,
};
webgpu_pipeline pipeline;
auto it = ctx->binary_pipelines.find(pipeline_key);
if (it != ctx->binary_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->binary_pipelines.emplace(pipeline_key, pipeline);
}
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
@@ -1785,28 +1710,18 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
}
static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
int inplace = ggml_webgpu_tensor_equal(src, dst);
bool inplace = ggml_webgpu_tensor_equal(src, dst);
ggml_webgpu_scale_pipeline_key key = { .inplace = inplace };
ggml_webgpu_scale_shader_lib_context shader_lib_ctx = {
.key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = inplace,
};
webgpu_pipeline pipeline;
auto it = ctx->scale_pipelines.find(key);
if (it != ctx->scale_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_scale_shader(ctx->p, wgsl_scale, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->scale_pipelines.emplace(key, pipeline);
}
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
// params unchanged
std::vector<uint32_t> params = {
@@ -1841,7 +1756,7 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@@ -1946,41 +1861,20 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_top_k = dst->op == GGML_OP_TOP_K;
// ascending order is 0, descending order is 1
const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0);
ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
.order = order
};
webgpu_pipeline argsort_pipeline;
auto it = ctx->argsort_pipelines.find(order);
if (it != ctx->argsort_pipelines.end()) {
argsort_pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx);
argsort_pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
argsort_pipeline.context = processed.decisions;
ctx->argsort_pipelines.emplace(order, argsort_pipeline);
}
auto * argsort_decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context.get());
webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
auto * argsort_decisions =
static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
webgpu_pipeline argsort_merge_pipeline;
it = ctx->argsort_merge_pipelines.find(order);
if (it != ctx->argsort_merge_pipelines.end()) {
argsort_merge_pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx);
argsort_merge_pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
argsort_merge_pipeline.context = processed.decisions;
ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline);
}
webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
const uint32_t src_ne0 = (uint32_t) src->ne[0];
const uint32_t nrows = (uint32_t) ggml_nrows(src);
@@ -2139,21 +2033,14 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
.vec4 = false,
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline;
auto it = ctx->cumsum_pipelines.find(1);
if (it != ctx->cumsum_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ctx->cumsum_pipelines.emplace(1, pipeline);
}
webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
uint32_t wg_x = ggml_nrows(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}

View File

@@ -638,8 +638,7 @@ struct Params {
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
return;
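The WGSL change above replaces the pipeline-overridable constant wg_size with a WG_SIZE preprocessor define, matching the removal of the ggml_webgpu_wg_size_entry helper in the second file. For contrast, the override-constant path being retired looked roughly like this on the C++ side (a sketch against Dawn's webgpu_cpp bindings; descriptor field names are from that API, the rest is illustrative):

// Old path (sketch): pass wg_size as a pipeline-creation constant.
wgpu::ConstantEntry wg_size_entry;
wg_size_entry.key   = "wg_size";            // matches `override wg_size: u32;`
wg_size_entry.value = (double) max_wg_size; // ConstantEntry values are doubles

wgpu::ComputePipelineDescriptor desc;
desc.compute.module        = shader_module;
desc.compute.entryPoint    = "main";
desc.compute.constants     = &wg_size_entry;
desc.compute.constantCount = 1;
// New path: the preprocessor substitutes WG_SIZE into the WGSL text, so
// @workgroup_size(WG_SIZE) is fixed at shader-module creation instead.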