Refactor all but flash attention, mat mul
parent 8a13bbb11b
commit a4e9b45306
@@ -62,6 +62,9 @@ struct ggml_webgpu_shader_lib_context {
     ggml_tensor * dst;
 
     uint32_t max_wg_size;
+    size_t   wg_mem_limit_bytes = 0;
+    bool     inplace            = false;
+    bool     overlap            = false;
 };
 
 struct webgpu_pipeline {
@@ -70,6 +73,18 @@ struct webgpu_pipeline {
     std::shared_ptr<void> context = nullptr;
 };
 
+struct ggml_webgpu_generic_shader_decisions {
+    uint32_t wg_size = 0;
+};
+
+/** Argsort **/
+
+struct ggml_webgpu_argsort_shader_lib_context {
+    uint32_t max_wg_size;
+    size_t   wg_mem_limit_bytes;
+    int32_t  order;
+};
+
 /** Set Rows **/
 
 struct ggml_webgpu_set_rows_pipeline_key {
@@ -98,12 +113,124 @@ struct ggml_webgpu_set_rows_shader_decisions {
     uint32_t wg_size;
 };
 
+/** Get Rows **/
+
+struct ggml_webgpu_get_rows_pipeline_key {
+    ggml_type src_type;
+    int       vectorized;
+
+    bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
+        return src_type == other.src_type && vectorized == other.vectorized;
+    }
+};
+
+struct ggml_webgpu_get_rows_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src_type);
+        ggml_webgpu_hash_combine(seed, key.vectorized);
+        return seed;
+    }
+};
+
+/** Pad **/
+struct ggml_webgpu_pad_pipeline_key {
+    bool circular;
+
+    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
+};
+
+struct ggml_webgpu_pad_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.circular);
+        return seed;
+    }
+};
+
+/** Scale **/
+
+struct ggml_webgpu_scale_pipeline_key {
+    int inplace;
+
+    bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
+};
+
+struct ggml_webgpu_scale_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
+/** Binary **/
+
+struct ggml_webgpu_binary_pipeline_key {
+    int  type;
+    int  op;
+    bool inplace;
+    bool overlap;
+
+    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
+        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
+    }
+};
+
+struct ggml_webgpu_binary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        ggml_webgpu_hash_combine(seed, key.overlap);
+        return seed;
+    }
+};
+
+/** Unary **/
+
+struct ggml_webgpu_unary_pipeline_key {
+    int  type;
+    int  op;
+    bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
+    bool inplace;
+
+    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
+        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+    }
+};
+
+struct ggml_webgpu_unary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.is_unary);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
 class ggml_webgpu_shader_lib {
     wgpu::Device           device;
     pre_wgsl::Preprocessor preprocessor;
 
     std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
     std::unordered_map<int, webgpu_pipeline> argmax_pipelines;   // key is vec4
+    std::unordered_map<int, webgpu_pipeline> argsort_pipelines;       // key is order
+    std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order
+    std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;        // key is fixed, no variants yet
+    std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
+        get_rows_pipelines; // src_type, vectorized
+    std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
+        unary_pipelines; // type/op/inplace
+    std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
+        scale_pipelines; // inplace
+    std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
+        pad_pipelines; // circular/non-circular
+    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
+        binary_pipelines; // type/op/inplace/overlap
 
     std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
        set_rows_pipelines;
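Aside: the key/hash pairs above all follow one recipe, because std::unordered_map needs both an equality operator on the key and a separate hasher type. A minimal self-contained sketch, assuming ggml_webgpu_hash_combine is the usual boost-style seed mixer (hash_combine_sketch, demo_key and demo_key_hash below are illustrative names, not part of the backend):

#include <cstddef>
#include <functional>
#include <unordered_map>

// Assumed behavior of ggml_webgpu_hash_combine: boost-style seed mixing.
template <typename T> static void hash_combine_sketch(size_t & seed, const T & v) {
    seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct demo_key {
    int  type;
    bool inplace;
    bool operator==(const demo_key & o) const { return type == o.type && inplace == o.inplace; }
};

struct demo_key_hash {
    size_t operator()(const demo_key & k) const {
        size_t seed = 0;
        hash_combine_sketch(seed, k.type);    // fold each field into the seed
        hash_combine_sketch(seed, k.inplace); // field order is fixed, so hashes are stable
        return seed;
    }
};

// The hasher is the third template argument, exactly as in the pipeline maps above.
static std::unordered_map<demo_key, int, demo_key_hash> demo_cache;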
@@ -191,6 +318,317 @@ class ggml_webgpu_shader_lib {
         return set_rows_pipelines[key];
     }
 
+    webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        auto it = cumsum_pipelines.find(1);
+        if (it != cumsum_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_cumsum, defines);
+        cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum");
+        return cumsum_pipelines[1];
+    }
+
+    webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        bool is_top_k = context.dst->op == GGML_OP_TOP_K;
+        // ascending order is 0, descending order is 1
+        const int32_t order =
+            is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
+
+        auto it = argsort_pipelines.find(order);
+        if (it != argsort_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string variant = "argsort";
+        defines.push_back(std::string("ORDER=") + std::to_string(order));
+        variant += std::string("_order") + std::to_string(order);
+        uint32_t wg_size = 1;
+        while (wg_size * 2 <= context.max_wg_size &&
+               wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
+            wg_size *= 2;
+        }
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        auto processed = preprocessor.preprocess(wgsl_argsort, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = wg_size;
+        argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
+        argsort_pipelines[order].context = decisions;
+        return argsort_pipelines[order];
+    }
+
+    webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        bool is_top_k = context.dst->op == GGML_OP_TOP_K;
+        // ascending order is 0, descending order is 1
+        const int32_t order =
+            is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
+
+        auto it = argsort_merge_pipelines.find(order);
+        if (it != argsort_merge_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string variant = "argsort_merge";
+        defines.push_back(std::string("ORDER=") + std::to_string(order));
+        variant += std::string("_order") + std::to_string(order);
+        uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines);
+        argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
+        return argsort_merge_pipelines[order];
+    }
+
+    webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
+        ggml_webgpu_get_rows_pipeline_key key = {
+            .src_type = context.src0->type,
+            .vectorized = (int) vectorized,
+        };
+
+        auto it = get_rows_pipelines.find(key);
+        if (it != get_rows_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string variant = "get_rows";
+
+        const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type);
+        const char * type_str = type_traits->type_name;
+
+        switch (key.src_type) {
+            case GGML_TYPE_F32:
+                if (key.vectorized) {
+                    defines.push_back("F32_VEC");
+                    defines.push_back("SRC_TYPE=vec4<f32>");
+                    defines.push_back("DST_TYPE=vec4<f32>");
+                    defines.push_back("BLOCK_SIZE=4u");
+                } else {
+                    defines.push_back("F32");
+                    defines.push_back("SRC_TYPE=f32");
+                    defines.push_back("DST_TYPE=f32");
+                    defines.push_back("BLOCK_SIZE=1u");
+                }
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("F16");
+                defines.push_back("SRC_TYPE=f16");
+                defines.push_back("DST_TYPE=f32");
+                defines.push_back("BLOCK_SIZE=1u");
+                variant += "_f16";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("I32");
+                defines.push_back("SRC_TYPE=i32");
+                defines.push_back("DST_TYPE=i32");
+                defines.push_back("BLOCK_SIZE=1u");
+                variant += "_i32";
+                break;
+            default: {
+                std::string type_upper = type_str;
+                std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                defines.push_back("BYTE_HELPERS");
+                defines.push_back(type_upper + "_T");
+                defines.push_back(type_upper);
+                defines.push_back(type_upper + "_SCALE_MIN");
+                defines.push_back(type_upper + "_TABLES");
+                defines.push_back(type_upper + "_GRID");
+
+                variant += "_";
+                variant += type_str;
+
+                defines.push_back(std::string("SRC_TYPE=") + type_str);
+                defines.push_back("DST_TYPE=f32");
+
+                if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
+                    key.src_type == GGML_TYPE_IQ4_NL) {
+                    defines.push_back("BLOCK_SIZE=32u");
+                } else if (key.src_type >= GGML_TYPE_Q2_K) {
+                    defines.push_back("BLOCK_SIZE=256u");
+                } else {
+                    defines.push_back("BLOCK_SIZE=1u");
+                }
+                break;
+            }
+        }
+
+        if (key.vectorized) {
+            variant += "_vec";
+        }
+
+        defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_get_rows, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context = decisions;
+        get_rows_pipelines[key] = pipeline;
+        return get_rows_pipelines[key];
+    }
+
+    webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
+
+        auto it = scale_pipelines.find(key);
+        if (it != scale_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string variant = "scale";
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_scale, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context = decisions;
+        scale_pipelines[key] = pipeline;
+        return scale_pipelines[key];
+    }
+
+    webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
+
+        auto it = pad_pipelines.find(key);
+        if (it != pad_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string variant = "pad";
+
+        if (key.circular) {
+            defines.push_back("CIRCULAR");
+            variant += "_circular";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_pad, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context = decisions;
+        pad_pipelines[key] = pipeline;
+        return pad_pipelines[key];
+    }
+
+    webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const bool is_unary = context.dst->op == GGML_OP_UNARY;
+        const int  op       = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
+        ggml_webgpu_unary_pipeline_key key = {
+            .type = context.dst->type,
+            .op = op,
+            .is_unary = is_unary,
+            .inplace = context.inplace,
+        };
+
+        auto it = unary_pipelines.find(key);
+        if (it != unary_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string variant = key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) :
+                                             ggml_op_name((ggml_op) key.op);
+        defines.push_back(variant);
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("TYPE_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for unary shader");
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_unary, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context = decisions;
+        unary_pipelines[key] = pipeline;
+        return unary_pipelines[key];
+    }
+
+    webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_binary_pipeline_key key = {
+            .type = context.dst->type,
+            .op = context.dst->op,
+            .inplace = context.inplace,
+            .overlap = context.overlap,
+        };
+
+        auto it = binary_pipelines.find(key);
+        if (it != binary_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string op_name = ggml_op_name((ggml_op) key.op);
+        std::string variant = op_name;
+
+        defines.push_back(std::string("OP_") + op_name);
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("TYPE_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for binary shader");
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        } else if (key.overlap) {
+            defines.push_back("OVERLAP");
+            variant += "_overlap";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_binary, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context = decisions;
+        binary_pipelines[key] = pipeline;
+        return binary_pipelines[key];
+    }
+
   private:
     static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
                                                        std::string    shader_code,
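Every get_*_pipeline method above is the same find-or-create cache wrapped around a JIT shader build. The skeleton, reduced to stand-in types (pipeline_t, get_or_create and MakeFn are illustrative, not backend API):

#include <memory>
#include <unordered_map>

struct pipeline_t {
    std::shared_ptr<void> context; // per-pipeline "decisions", type-erased
};

template <typename Key, typename Hash, typename MakeFn>
static pipeline_t & get_or_create(std::unordered_map<Key, pipeline_t, Hash> & cache,
                                  const Key & key, MakeFn make) {
    auto it = cache.find(key);
    if (it != cache.end()) {
        return it->second;          // compiled once, reused on every later dispatch
    }
    return cache[key] = make();     // first use pays the preprocess + compile cost
}

One design consequence: because the key captures every define that changes the WGSL (type, op, inplace, overlap, and so on), two tensors that preprocess to identical shader source always share one compiled pipeline.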
@@ -404,443 +842,6 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
     result.decisions = decisions;
     return result;
 }
-
-/** Generic **/
-
-struct ggml_webgpu_generic_shader_lib_context {
-    int      vec4;
-    uint32_t max_wg_size;
-};
-
-struct ggml_webgpu_generic_shader_decisions {
-    uint32_t wg_size;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
-    pre_wgsl::Preprocessor &                       preprocessor,
-    const char *                                   shader_src,
-    const ggml_webgpu_generic_shader_lib_context & context,
-    const std::string &                            base_variant) {
-    std::vector<std::string> defines;
-    std::string variant = base_variant;
-
-    if (context.vec4) {
-        defines.push_back("VEC4");
-        variant += "_vec";
-    }
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    return result;
-}
-
-/** Pad **/
-
-struct ggml_webgpu_pad_pipeline_key {
-    bool circular;
-
-    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
-};
-
-struct ggml_webgpu_pad_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.circular);
-        return seed;
-    }
-};
-
-struct ggml_webgpu_pad_shader_lib_context {
-    ggml_webgpu_pad_pipeline_key key;
-    uint32_t                     max_wg_size;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
-    pre_wgsl::Preprocessor &                   preprocessor,
-    const char *                               shader_src,
-    const ggml_webgpu_pad_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string variant = "pad";
-
-    if (context.key.circular) {
-        defines.push_back("CIRCULAR");
-        variant += "_circular";
-    }
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions = decisions;
-    return result;
-}
-
-/** Argsort **/
-
-struct ggml_webgpu_argsort_shader_lib_context {
-    uint32_t max_wg_size;
-    size_t   wg_mem_limit_bytes;
-    int32_t  order;
-};
-
-struct ggml_webgpu_argsort_shader_decisions {
-    uint32_t wg_size = 0;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
-    pre_wgsl::Preprocessor &                       preprocessor,
-    const char *                                   shader_src,
-    const ggml_webgpu_argsort_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string variant = "argsort";
-    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
-    variant += std::string("_order") + std::to_string(context.order);
-    uint32_t wg_size = 1;
-    while (wg_size * 2 <= context.max_wg_size &&
-           wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
-        wg_size *= 2;
-    }
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
-    ggml_webgpu_processed_shader result;
-    result.wgsl = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
-    decisions->wg_size = wg_size;
-    result.decisions = decisions;
-    return result;
-}
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
-    pre_wgsl::Preprocessor &                       preprocessor,
-    const char *                                   shader_src,
-    const ggml_webgpu_argsort_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string variant = "argsort_merge";
-    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
-    variant += std::string("_order") + std::to_string(context.order);
-    uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
-    ggml_webgpu_processed_shader result;
-    result.wgsl = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
-    decisions->wg_size = wg_size;
-    result.decisions = decisions;
-    return result;
-}
-
-struct ggml_webgpu_unary_pipeline_key {
-    int  type;
-    int  op;
-    bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
-    bool inplace;
-
-    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
-        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
-    }
-};
-
-struct ggml_webgpu_unary_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.type);
-        ggml_webgpu_hash_combine(seed, key.op);
-        ggml_webgpu_hash_combine(seed, key.is_unary);
-        ggml_webgpu_hash_combine(seed, key.inplace);
-        return seed;
-    }
-};
-
-struct ggml_webgpu_unary_shader_lib_context {
-    ggml_webgpu_unary_pipeline_key key;
-    uint32_t                       max_wg_size;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
-    pre_wgsl::Preprocessor &                     preprocessor,
-    const char *                                 shader_src,
-    const ggml_webgpu_unary_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
-                                                 ggml_op_name((ggml_op) context.key.op);
-    // Operation-specific behavior
-    defines.push_back(variant);
-
-    switch (context.key.type) {
-        case GGML_TYPE_F32:
-            defines.push_back("TYPE_F32");
-            variant += "_f32";
-            break;
-        case GGML_TYPE_F16:
-            defines.push_back("TYPE_F16");
-            variant += "_f16";
-            break;
-        default:
-            GGML_ABORT("Unsupported type for unary shader");
-    }
-
-    if (context.key.inplace) {
-        defines.push_back("INPLACE");
-        variant += "_inplace";
-    }
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions = decisions;
-    return result;
-}
-
-/** Scale **/
-
-struct ggml_webgpu_scale_pipeline_key {
-    int inplace;
-
-    bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
-};
-
-struct ggml_webgpu_scale_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.inplace);
-        return seed;
-    }
-};
-
-/** Binary **/
-
-struct ggml_webgpu_binary_pipeline_key {
-    int  type;
-    int  op;
-    bool inplace;
-    bool overlap;
-
-    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
-        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
-    }
-};
-
-struct ggml_webgpu_binary_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.type);
-        ggml_webgpu_hash_combine(seed, key.op);
-        ggml_webgpu_hash_combine(seed, key.inplace);
-        ggml_webgpu_hash_combine(seed, key.overlap);
-        return seed;
-    }
-};
-
-struct ggml_webgpu_scale_shader_lib_context {
-    ggml_webgpu_scale_pipeline_key key;
-    uint32_t                       max_wg_size;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_scale_shader(
-    pre_wgsl::Preprocessor &                     preprocessor,
-    const char *                                 shader_src,
-    const ggml_webgpu_scale_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string variant = "scale";
-
-    if (context.key.inplace) {
-        defines.push_back("INPLACE");
-        variant += "_inplace";
-    }
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions = decisions;
-    return result;
-}
-
-struct ggml_webgpu_binary_shader_lib_context {
-    ggml_webgpu_binary_pipeline_key key;
-    uint32_t                        max_wg_size;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
-    pre_wgsl::Preprocessor &                      preprocessor,
-    const char *                                  shader_src,
-    const ggml_webgpu_binary_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string op_name = ggml_op_name((ggml_op) context.key.op);
-    std::string variant = op_name;
-
-    defines.push_back(std::string("OP_") + op_name);
-
-    switch (context.key.type) {
-        case GGML_TYPE_F32:
-            defines.push_back("TYPE_F32");
-            variant += "_f32";
-            break;
-        case GGML_TYPE_F16:
-            defines.push_back("TYPE_F16");
-            variant += "_f16";
-            break;
-        default:
-            GGML_ABORT("Unsupported type for binary shader");
-    }
-
-    if (context.key.inplace) {
-        defines.push_back("INPLACE");
-        variant += "_inplace";
-    } else if (context.key.overlap) {
-        defines.push_back("OVERLAP");
-        variant += "_overlap";
-    }
-
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-    ggml_webgpu_processed_shader result;
-    result.wgsl = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions = decisions;
-    return result;
-}
-
-/** get_rows */
-
-struct ggml_webgpu_get_rows_pipeline_key {
-    ggml_type src_type;
-    int       vectorized;
-
-    bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
-        return src_type == other.src_type && vectorized == other.vectorized;
-    }
-};
-
-struct ggml_webgpu_get_rows_pipeline_key_hash {
-    size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.src_type);
-        ggml_webgpu_hash_combine(seed, key.vectorized);
-        return seed;
-    }
-};
-
-struct ggml_webgpu_get_rows_shader_lib_context {
-    ggml_webgpu_get_rows_pipeline_key key;
-    uint32_t                          max_wg_size;
-};
-
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_get_rows_shader(
-    pre_wgsl::Preprocessor &                        preprocessor,
-    const char *                                    shader_src,
-    const ggml_webgpu_get_rows_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string variant = "get_rows";
-
-    // Determine src type string
-    const char * type_str = nullptr;
-
-    // src type
-    const struct ggml_type_traits * type_traits = ggml_get_type_traits(context.key.src_type);
-    type_str = type_traits->type_name;
-
-    switch (context.key.src_type) {
-        case GGML_TYPE_F32:
-            if (context.key.vectorized) {
-                defines.push_back("F32_VEC");
-                defines.push_back("SRC_TYPE=vec4<f32>");
-                defines.push_back("DST_TYPE=vec4<f32>");
-                defines.push_back("BLOCK_SIZE=4u");
-            } else {
-                defines.push_back("F32");
-                defines.push_back("SRC_TYPE=f32");
-                defines.push_back("DST_TYPE=f32");
-                defines.push_back("BLOCK_SIZE=1u");
-            }
-            variant += "_f32";
-            break;
-        case GGML_TYPE_F16:
-            defines.push_back("F16");
-            defines.push_back("SRC_TYPE=f16");
-            defines.push_back("DST_TYPE=f32");
-            defines.push_back("BLOCK_SIZE=1u");
-            variant += "_f16";
-            break;
-        case GGML_TYPE_I32:
-            defines.push_back("I32");
-            defines.push_back("SRC_TYPE=i32");
-            defines.push_back("DST_TYPE=i32");
-            defines.push_back("BLOCK_SIZE=1u");
-            variant += "_i32";
-            break;
-        default:
-            // convert name to upper case for other defines
-            std::string type_upper = type_str;
-            std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
-
-            // push back defines for quantized types
-            defines.push_back("BYTE_HELPERS");
-            defines.push_back(type_upper + "_T");
-            defines.push_back(type_upper);
-
-            // for q4_k and q5_k
-            defines.push_back(type_upper + "_SCALE_MIN");
-
-            // defines for i-quants
-            defines.push_back(type_upper + "_TABLES");
-            defines.push_back(type_upper + "_GRID");
-
-            // add variant
-            variant += "_";
-            variant += type_str;
-
-            // add define for quantized src0 type
-            defines.push_back(std::string("SRC_TYPE=") + type_str);
-            defines.push_back("DST_TYPE=f32");
-            break;
-    }
-
-    // determine block_size for quantized types
-    if (context.key.src_type == GGML_TYPE_I32) {
-        defines.push_back("BLOCK_SIZE=1u");
-    } else if ((context.key.src_type >= GGML_TYPE_Q4_0 && context.key.src_type <= GGML_TYPE_Q8_1) ||
-               context.key.src_type == GGML_TYPE_IQ4_NL) {
-        // Non-K quants use 32
-        defines.push_back("BLOCK_SIZE=32u");
-    } else if (context.key.src_type >= GGML_TYPE_Q2_K) {
-        // K-quants and IQ variants all use 256
-        defines.push_back("BLOCK_SIZE=256u");
-    }
-
-    // Vectorized suffix
-    if (context.key.vectorized) {
-        variant += "_vec";
-    }
-
-    defines.push_back("WORKGROUP_SIZE=" + std::to_string(context.max_wg_size));
-
-    ggml_webgpu_processed_shader result;
-    result.wgsl = preprocessor.preprocess(shader_src, defines);
-    result.variant = variant;
-
-    // Create decisions structure to store workgroup size
-    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
-    decisions->wg_size = context.max_wg_size;
-    result.decisions = decisions;
-
-    return result;
-}
-
 /** Matrix Multiplication **/
 
 struct ggml_webgpu_mul_mat_pipeline_key {
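The argsort workgroup-size loop deleted here survives verbatim in get_argsort_pipeline above. Standalone, it picks the largest power of two that respects both the invocation limit and half of the workgroup-storage budget (one i32 per invocation); the constants below are illustrative stand-ins for the device limits and GGML_WEBGPU_I32_SIZE_BYTES:

#include <cstddef>
#include <cstdint>
#include <cstdio>

static uint32_t pick_argsort_wg_size(uint32_t max_wg_size, size_t wg_mem_limit_bytes) {
    const size_t i32_size = 4; // stand-in for GGML_WEBGPU_I32_SIZE_BYTES
    uint32_t wg_size = 1;
    // Double while the next size still fits the invocation limit and the
    // current size stays within half of the workgroup-memory budget.
    while (wg_size * 2 <= max_wg_size && wg_size * i32_size <= wg_mem_limit_bytes / 2) {
        wg_size *= 2;
    }
    return wg_size;
}

int main() {
    printf("%u\n", pick_argsort_wg_size(256, 16384)); // prints 256 for these limits
    return 0;
}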
@@ -369,32 +369,13 @@ struct webgpu_context_struct {
     std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
         flash_attn_pipelines;
 
-    std::unordered_map<int, webgpu_pipeline> argsort_pipelines;       // key is order (asc/desc)
-    std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
-    std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;        // key is fixed, no variants yet
-
-    std::unordered_map<ggml_webgpu_get_rows_pipeline_key,
-                       webgpu_pipeline,
-                       ggml_webgpu_get_rows_pipeline_key_hash>
-        get_rows_pipelines; // src_type, vectorized
-
     std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
 
-    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
-        binary_pipelines;
-
     std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines;  // glu_op, type, split
 
-    std::unordered_map<ggml_webgpu_scale_pipeline_key,
-                       webgpu_pipeline,
-                       ggml_webgpu_scale_pipeline_key_hash>
-        scale_pipelines; // inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
-    std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
-        unary_pipelines;
-    std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> pad_pipelines;
 
     size_t memset_bytes_per_thread;
 };
@@ -909,24 +890,13 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
 }
 
 static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    const bool circular = ggml_get_op_params_i32(dst, 8) != 0;
-
-    ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular };
-    ggml_webgpu_pad_shader_lib_context shader_lib_ctx = {
-        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src,
+        .dst = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
     };
 
-    webgpu_pipeline pipeline;
-    auto it = ctx->pad_pipelines.find(pipeline_key);
-    if (it != ctx->pad_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->pad_pipelines.emplace(pipeline_key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
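This call site shows why the context struct grew defaulted members in the first hunk: with C++20 designated initializers, pad can name only .src0, .dst and .max_wg_size, and .inplace/.overlap silently keep their false defaults. A sketch with stand-in types (tensor_t and shader_lib_context_sketch are illustrative, not the backend's definitions):

#include <cstddef>
#include <cstdint>

struct tensor_t; // opaque stand-in for ggml_tensor

struct shader_lib_context_sketch {
    tensor_t * src0 = nullptr;
    tensor_t * src1 = nullptr;
    tensor_t * dst  = nullptr;
    uint32_t   max_wg_size        = 0;
    size_t     wg_mem_limit_bytes = 0;
    bool       inplace            = false;
    bool       overlap            = false;
};

static shader_lib_context_sketch make_pad_ctx(tensor_t * src, tensor_t * dst, uint32_t max_wg) {
    // Unnamed members keep their in-class defaults, mirroring the call site above.
    return { .src0 = src, .dst = dst, .max_wg_size = max_wg };
}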
@@ -1061,37 +1031,15 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
                                            ggml_tensor * src,
                                            ggml_tensor * idx,
                                            ggml_tensor * dst) {
-    uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src,
+        .src1 = nullptr,
+        .dst = dst,
+        .max_wg_size = WEBGPU_MAX_WG_SIZE,
+    };
 
-    // Create pipeline key
-    ggml_webgpu_get_rows_pipeline_key key = { .src_type = src->type, .vectorized = (int) vectorized };
-
-    // Get or create pipeline
-    webgpu_pipeline pipeline;
-    auto it = ctx->get_rows_pipelines.find(key);
-    if (it != ctx->get_rows_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        // JIT compile the shader
-        const char * shader_src = wgsl_get_rows;
-
-        // Build shader context
-        ggml_webgpu_get_rows_shader_lib_context shader_lib_ctx = {
-            .key = key,
-            .max_wg_size = WEBGPU_MAX_WG_SIZE,
-        };
-
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_get_rows_shader(ctx->p, shader_src, shader_lib_ctx);
-
-        std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(shader_lib_ctx.max_wg_size);
-        pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
-                                               processed.variant.c_str(), constants);
-        pipeline.context = processed.decisions;
-        ctx->get_rows_pipelines.emplace(key, pipeline);
-    }
-
-    auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
     std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
                                      (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
@@ -1445,27 +1393,16 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
 static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool is_unary = dst->op == GGML_OP_UNARY;
     bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
-    int  op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op;
 
-    ggml_webgpu_unary_pipeline_key pipeline_key = {
-        .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace
-    };
-    ggml_webgpu_unary_shader_lib_context shader_lib_ctx = {
-        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src,
+        .src1 = nullptr,
+        .dst = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace = inplace,
     };
 
-    webgpu_pipeline pipeline;
-    auto it = ctx->unary_pipelines.find(pipeline_key);
-    if (it != ctx->unary_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->unary_pipelines.emplace(pipeline_key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
@@ -1537,30 +1474,18 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
                                             ggml_tensor * dst) {
     binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
 
-    ggml_webgpu_binary_pipeline_key pipeline_key = {
-        .type = dst->type,
-        .op = dst->op,
-        .inplace = flags.inplace,
-        .overlap = flags.overlap,
-    };
-    ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
-        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src0,
+        .src1 = src1,
+        .dst = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace = flags.inplace,
+        .overlap = flags.overlap,
     };
 
-    webgpu_pipeline pipeline;
-    auto it = ctx->binary_pipelines.find(pipeline_key);
-    if (it != ctx->binary_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->binary_pipelines.emplace(pipeline_key, pipeline);
-    }
+    webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
 
-    auto * decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(pipeline.context.get());
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
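Note the one-line behavior fix buried in this hunk: the old code cast the binary pipeline's context to ggml_webgpu_argsort_shader_decisions even though a generic decisions object was stored. Because webgpu_pipeline::context is a type-erased std::shared_ptr<void>, the compiler cannot catch such a mismatch; the cast must match what was stored. A compilable sketch of the mechanism (names are stand-ins):

#include <cstdint>
#include <memory>

struct generic_decisions { uint32_t wg_size = 0; };

struct pipeline_sketch {
    std::shared_ptr<void> context; // shared_ptr<void> still runs the correct destructor
};

int main() {
    auto decisions = std::make_shared<generic_decisions>();
    decisions->wg_size = 256;

    pipeline_sketch p;
    p.context = decisions; // the deleter of generic_decisions is captured here

    // Retrieval: static_cast back to the concrete type that was stored.
    auto * d = static_cast<generic_decisions *>(p.context.get());
    return d->wg_size == 256 ? 0 : 1;
}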
@@ -1785,28 +1710,18 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
 }
 
 static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    int  inplace = ggml_webgpu_tensor_equal(src, dst);
+    bool inplace = ggml_webgpu_tensor_equal(src, dst);
 
-    ggml_webgpu_scale_pipeline_key key = { .inplace = inplace };
-
-    ggml_webgpu_scale_shader_lib_context shader_lib_ctx = {
-        .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src,
+        .src1 = nullptr,
+        .dst = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace = inplace,
     };
 
-    webgpu_pipeline pipeline;
-    auto it = ctx->scale_pipelines.find(key);
-    if (it != ctx->scale_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_scale_shader(ctx->p, wgsl_scale, shader_lib_ctx);
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        pipeline.context = processed.decisions;
-        ctx->scale_pipelines.emplace(key, pipeline);
-    }
-
-    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
 
     // params unchanged
     std::vector<uint32_t> params = {
@@ -1841,7 +1756,7 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
                           .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
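CEIL_DIV here computes the dispatch width: how many workgroups of decisions->wg_size invocations are needed to cover every element. Assuming the usual definition of the macro, an equivalent checked sketch:

#include <cstdint>

static constexpr uint32_t ceil_div(uint32_t n, uint32_t d) {
    return (n + d - 1) / d; // rounds up: the last workgroup may be partially idle
}

// 1000 elements with a workgroup size of 256 -> 3 full groups + 1 partial group.
static_assert(ceil_div(1000, 256) == 4, "dispatch must cover all elements");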
@@ -1946,41 +1861,20 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
 
 static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    bool is_top_k = dst->op == GGML_OP_TOP_K;
-    // ascending order is 0, descending order is 1
-    const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0);
 
-    ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src,
+        .src1 = nullptr,
+        .dst = dst,
         .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
         .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
-        .order = order
     };
 
-    webgpu_pipeline argsort_pipeline;
-    auto it = ctx->argsort_pipelines.find(order);
-    if (it != ctx->argsort_pipelines.end()) {
-        argsort_pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx);
-        argsort_pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        argsort_pipeline.context = processed.decisions;
-        ctx->argsort_pipelines.emplace(order, argsort_pipeline);
-    }
-    auto * argsort_decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context.get());
+    webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
+    auto * argsort_decisions =
+        static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
 
-    webgpu_pipeline argsort_merge_pipeline;
-    it = ctx->argsort_merge_pipelines.find(order);
-    if (it != ctx->argsort_merge_pipelines.end()) {
-        argsort_merge_pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx);
-        argsort_merge_pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        argsort_merge_pipeline.context = processed.decisions;
-        ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline);
-    }
+    webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
 
     const uint32_t src_ne0 = (uint32_t) src->ne[0];
     const uint32_t nrows = (uint32_t) ggml_nrows(src);
 
@@ -2139,21 +2033,14 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
                           .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
-        .vec4 = false,
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src,
+        .src1 = nullptr,
+        .dst = dst,
         .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
     };
-    webgpu_pipeline pipeline;
-    auto it = ctx->cumsum_pipelines.find(1);
-    if (it != ctx->cumsum_pipelines.end()) {
-        pipeline = it->second;
-    } else {
-        ggml_webgpu_processed_shader processed =
-            ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum");
-        pipeline =
-            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-        ctx->cumsum_pipelines.emplace(1, pipeline);
-    }
 
+    webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
     uint32_t wg_x = ggml_nrows(dst);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
@@ -638,8 +638,7 @@ struct Params {
 @group(0) @binding(3)
 var<uniform> params: Params;
 
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
+@compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
         return;
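This last hunk trades a pipeline-overridable constant (override wg_size, bound at pipeline-creation time) for a preprocessor define baked into the WGSL source, which is what lets the shader lib drop the wgpu::ConstantEntry plumbing at the call sites above. A toy model of the KEY=VALUE substitution, assuming pre_wgsl::Preprocessor acts as a plain token replacer for such defines (apply_defines is illustrative, not its real API):

#include <string>
#include <vector>

static std::string apply_defines(std::string src, const std::vector<std::string> & defines) {
    for (const auto & d : defines) {
        const auto eq = d.find('=');
        if (eq == std::string::npos) {
            continue; // flag-style defines (e.g. "INPLACE") gate conditional blocks instead
        }
        const std::string key = d.substr(0, eq);
        const std::string val = d.substr(eq + 1);
        for (size_t pos = 0; (pos = src.find(key, pos)) != std::string::npos; pos += val.size()) {
            src.replace(pos, key.size(), val); // replace every occurrence of the token
        }
    }
    return src;
}

// apply_defines("@compute @workgroup_size(WG_SIZE)", { "WG_SIZE=256" })
//   yields "@compute @workgroup_size(256)".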