ggml webgpu: ops support for qwen3.5 (SET, TRI_SOLVE, SSM_CONV, GATED_DELTA_NET) + GET_ROWS optimization (#20687)
* Implement l2_norm, set, tri * Add DIAG/SOLVE_TRI * Add SSM_CONV * Better get_rows and gated_delta_net to support qwen3.5 * Clean up, update ops.md * Fix binding_index type for wasm * Fix read write annotations * cleanups
This commit is contained in:
parent
922b90e567
commit
c1258830b2
|
|
@ -47,7 +47,7 @@ Legend:
|
|||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
|
|
@ -91,7 +91,7 @@ Legend:
|
|||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
|
|
@ -101,10 +101,10 @@ Legend:
|
|||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
|
|
|||
8714
docs/ops/WebGPU.csv
8714
docs/ops/WebGPU.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions {
|
|||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_ssm_conv_shader_decisions {
|
||||
uint32_t block_size;
|
||||
uint32_t tokens_per_wg;
|
||||
};
|
||||
|
||||
/** Argsort **/
|
||||
|
||||
struct ggml_webgpu_argsort_shader_lib_context {
|
||||
|
|
@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions {
|
|||
uint32_t wg_size;
|
||||
};
|
||||
|
||||
/** Set **/
|
||||
|
||||
struct ggml_webgpu_set_pipeline_key {
|
||||
ggml_type type;
|
||||
bool inplace;
|
||||
|
||||
bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
|
||||
return type == other.type && inplace == other.inplace;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_set_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Get Rows **/
|
||||
|
||||
struct ggml_webgpu_get_rows_pipeline_key {
|
||||
|
|
@ -186,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash {
|
|||
}
|
||||
};
|
||||
|
||||
/** Solve Tri **/
|
||||
struct ggml_webgpu_solve_tri_pipeline_key {
|
||||
int type;
|
||||
int n;
|
||||
int k;
|
||||
|
||||
bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
|
||||
return type == other.type && n == other.n && k == other.k;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_solve_tri_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.n);
|
||||
ggml_webgpu_hash_combine(seed, key.k);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** SSM Conv **/
|
||||
struct ggml_webgpu_ssm_conv_pipeline_key {
|
||||
int type;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
|
||||
return type == other.type && vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
/** Gated Delta Net **/
|
||||
struct ggml_webgpu_gated_delta_net_pipeline_key {
|
||||
int type;
|
||||
int s_v;
|
||||
int kda;
|
||||
|
||||
bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
|
||||
return type == other.type && s_v == other.s_v && kda == other.kda;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.s_v);
|
||||
ggml_webgpu_hash_combine(seed, key.kda);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_ssm_conv_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
/** Scale **/
|
||||
|
||||
struct ggml_webgpu_scale_pipeline_key {
|
||||
|
|
@ -466,14 +552,22 @@ class ggml_webgpu_shader_lib {
|
|||
unary_pipelines; // type/op/inplace
|
||||
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
|
||||
scale_pipelines; // inplace
|
||||
std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
|
||||
solve_tri_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
|
||||
ssm_conv_pipelines; // type/vectorized
|
||||
std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_gated_delta_net_pipeline_key_hash>
|
||||
gated_delta_net_pipelines; // type/S_v/kda
|
||||
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
|
||||
pad_pipelines; // circular/non-circular
|
||||
pad_pipelines; // circular/non-circular
|
||||
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
||||
binary_pipelines; // type/op/inplace/overlap
|
||||
binary_pipelines; // type/op/inplace/overlap
|
||||
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
|
||||
concat_pipelines; // type
|
||||
concat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
||||
repeat_pipelines; // type
|
||||
repeat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_pipelines;
|
||||
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
|
||||
|
|
@ -487,6 +581,7 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
||||
set_rows_pipelines;
|
||||
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
|
||||
|
||||
public:
|
||||
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
||||
|
|
@ -519,11 +614,11 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
switch (key.op) {
|
||||
case GGML_OP_RMS_NORM:
|
||||
defines.push_back("OP_RMS_NORM");
|
||||
defines.push_back("RMS_NORM");
|
||||
variant = "rms_norm";
|
||||
break;
|
||||
case GGML_OP_L2_NORM:
|
||||
defines.push_back("OP_L2_NORM");
|
||||
defines.push_back("L2_NORM");
|
||||
variant = "l2_norm";
|
||||
break;
|
||||
default:
|
||||
|
|
@ -535,8 +630,9 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_inplace";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
const uint32_t row_norm_wg_size = 128u;
|
||||
uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
|
||||
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
return row_norm_pipelines[key];
|
||||
|
|
@ -609,6 +705,46 @@ class ggml_webgpu_shader_lib {
|
|||
return set_rows_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
|
||||
|
||||
auto it = set_pipelines.find(key);
|
||||
if (it != set_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "set";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("TYPE_F32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
defines.push_back("TYPE_I32");
|
||||
variant += "_i32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for set shader");
|
||||
}
|
||||
|
||||
if (key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_set, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
set_pipelines[key] = pipeline;
|
||||
return set_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
auto it = cumsum_pipelines.find(1);
|
||||
if (it != cumsum_pipelines.end()) {
|
||||
|
|
@ -695,6 +831,7 @@ class ggml_webgpu_shader_lib {
|
|||
|
||||
switch (key.src_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
if (key.vectorized) {
|
||||
defines.push_back("F32_VEC");
|
||||
defines.push_back("SRC_TYPE=vec4<f32>");
|
||||
|
|
@ -709,6 +846,7 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
defines.push_back("F16");
|
||||
defines.push_back("SRC_TYPE=f16");
|
||||
defines.push_back("DST_TYPE=f32");
|
||||
|
|
@ -716,6 +854,7 @@ class ggml_webgpu_shader_lib {
|
|||
variant += "_f16";
|
||||
break;
|
||||
case GGML_TYPE_I32:
|
||||
defines.push_back("FLOAT_PARALLEL");
|
||||
defines.push_back("I32");
|
||||
defines.push_back("SRC_TYPE=i32");
|
||||
defines.push_back("DST_TYPE=i32");
|
||||
|
|
@ -794,6 +933,128 @@ class ggml_webgpu_shader_lib {
|
|||
return scale_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_solve_tri_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.n = (int) context.src0->ne[0],
|
||||
.k = (int) context.src1->ne[0],
|
||||
};
|
||||
|
||||
auto it = solve_tri_pipelines.find(key);
|
||||
if (it != solve_tri_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "solve_tri";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for solve_tri shader");
|
||||
}
|
||||
|
||||
const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
|
||||
const uint32_t k_tile = wg_size;
|
||||
const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
|
||||
|
||||
defines.push_back(std::string("N=") + std::to_string(key.n));
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
|
||||
defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
decisions->wg_size = wg_size;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
solve_tri_pipelines[key] = pipeline;
|
||||
return solve_tri_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_ssm_conv_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.vectorized = context.src1->ne[0] == 4,
|
||||
};
|
||||
|
||||
auto it = ssm_conv_pipelines.find(key);
|
||||
if (it != ssm_conv_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "ssm_conv";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for ssm_conv shader");
|
||||
}
|
||||
|
||||
if (key.vectorized) {
|
||||
defines.push_back("VECTORIZED");
|
||||
variant += "_vec4";
|
||||
}
|
||||
|
||||
constexpr uint32_t block_size = 32u;
|
||||
constexpr uint32_t tokens_per_wg = 8u;
|
||||
|
||||
defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
|
||||
defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
|
||||
decisions->block_size = block_size;
|
||||
decisions->tokens_per_wg = tokens_per_wg;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
ssm_conv_pipelines[key] = pipeline;
|
||||
return ssm_conv_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_gated_delta_net_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.s_v = (int) context.src2->ne[0],
|
||||
.kda = context.src3->ne[0] == context.src2->ne[0],
|
||||
};
|
||||
|
||||
auto it = gated_delta_net_pipelines.find(key);
|
||||
if (it != gated_delta_net_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "gated_delta_net";
|
||||
|
||||
switch (key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
variant += "_f32";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for gated_delta_net shader");
|
||||
}
|
||||
|
||||
if (key.kda) {
|
||||
defines.push_back("KDA");
|
||||
variant += "_kda";
|
||||
}
|
||||
|
||||
defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
|
||||
defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
gated_delta_net_pipelines[key] = pipeline;
|
||||
return gated_delta_net_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
|
||||
|
||||
|
|
|
|||
|
|
@ -880,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
|
|||
params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.inplace = inplace,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
|
||||
const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size),
|
||||
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
|
||||
1u,
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size),
|
||||
(uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
uint32_t binding_index = 0;
|
||||
if (!inplace) {
|
||||
entries.push_back({ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) });
|
||||
binding_index++;
|
||||
}
|
||||
entries.push_back({ .binding = binding_index,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) });
|
||||
entries.push_back({ .binding = binding_index + 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
||||
|
|
@ -935,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
|
|||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
|
||||
const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_ssm_conv_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
|
||||
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
|
||||
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
|
||||
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) dst->ne[1],
|
||||
(uint32_t) dst->ne[2],
|
||||
token_tiles,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
|
||||
const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2];
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * src2,
|
||||
ggml_tensor * src3,
|
||||
ggml_tensor * src4,
|
||||
ggml_tensor * src5,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src0,
|
||||
.src1 = src1,
|
||||
.src2 = src2,
|
||||
.src3 = src3,
|
||||
.src4 = src4,
|
||||
.dst = dst,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx);
|
||||
|
||||
const uint32_t s_v = (uint32_t) src2->ne[0];
|
||||
const uint32_t h = (uint32_t) src2->ne[1];
|
||||
const uint32_t n_tokens = (uint32_t) src2->ne[2];
|
||||
const uint32_t n_seqs = (uint32_t) src2->ne[3];
|
||||
const float scale = 1.0f / sqrtf((float) s_v);
|
||||
uint32_t scale_u32;
|
||||
memcpy(&scale_u32, &scale, sizeof(scale_u32));
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
h,
|
||||
n_tokens,
|
||||
n_seqs,
|
||||
s_v * h * n_tokens * n_seqs,
|
||||
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
|
||||
(uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
|
||||
(uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
|
||||
(uint32_t) (src2->nb[3] / ggml_type_size(src2->type)),
|
||||
|
||||
(uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
|
||||
(uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
|
||||
(uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
|
||||
|
||||
(uint32_t) src0->ne[1],
|
||||
(uint32_t) (src2->ne[3] / src0->ne[3]),
|
||||
scale_u32,
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(src2),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src2) },
|
||||
{ .binding = 3,
|
||||
.buffer = ggml_webgpu_tensor_buf(src3),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src3),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src3) },
|
||||
{ .binding = 4,
|
||||
.buffer = ggml_webgpu_tensor_buf(src4),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src4),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src4) },
|
||||
{ .binding = 5,
|
||||
.buffer = ggml_webgpu_tensor_buf(src5),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src5),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src5) },
|
||||
{ .binding = 6,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs);
|
||||
}
|
||||
|
||||
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
ggml_tensor * src,
|
||||
ggml_tensor * idx,
|
||||
|
|
@ -1016,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
|||
ggml_tensor * src,
|
||||
ggml_tensor * idx,
|
||||
ggml_tensor * dst) {
|
||||
const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src,
|
||||
.src1 = nullptr,
|
||||
|
|
@ -1060,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
|
|||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
|
||||
uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1));
|
||||
uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]);
|
||||
uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows;
|
||||
uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size);
|
||||
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
|
@ -1632,7 +1901,7 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s
|
|||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src,
|
||||
.dst = dst,
|
||||
.max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
.inplace = inplace,
|
||||
};
|
||||
|
||||
|
|
@ -2176,6 +2445,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
return ggml_webgpu_cpy(ctx, src0, node);
|
||||
case GGML_OP_SET:
|
||||
return ggml_webgpu_set(ctx, src0, src1, node);
|
||||
case GGML_OP_SET_ROWS:
|
||||
return ggml_webgpu_set_rows(ctx, src0, src1, node);
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
|
@ -2219,6 +2490,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
case GGML_OP_DIAG:
|
||||
case GGML_OP_TRI:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
return ggml_webgpu_solve_tri(ctx, src0, src1, node);
|
||||
case GGML_OP_SSM_CONV:
|
||||
return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
|
||||
case GGML_OP_PAD:
|
||||
return ggml_webgpu_pad(ctx, src0, node);
|
||||
case GGML_OP_ARGMAX:
|
||||
|
|
@ -2957,7 +3234,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
|
|||
/* .is_host = */ NULL, // defaults to false
|
||||
},
|
||||
/* .device = */
|
||||
dev,
|
||||
dev,
|
||||
/* .context = */ NULL
|
||||
};
|
||||
|
||||
|
|
@ -3040,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
|
||||
(op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
|
||||
break;
|
||||
case GGML_OP_SET:
|
||||
supports_op = src0->type == src1->type && src0->type == op->type &&
|
||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
|
||||
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
|
||||
|
|
@ -3180,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_SOLVE_TRI:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
supports_op = op->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
{
|
||||
const uint32_t s_v = (uint32_t) src2->ne[0];
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
|
||||
src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 &&
|
||||
op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 &&
|
||||
s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
|
|
@ -3201,12 +3503,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
|||
case GGML_OP_COS:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_DIAG:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
|
||||
break;
|
||||
case GGML_OP_PAD:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,132 @@
|
|||
// Gated delta-rule linear attention (GGML_OP_GATED_DELTA_NET).
// One workgroup per (head, sequence) pair; each invocation `col` owns one
// element of the value dimension, i.e. one column of the S_V x S_V recurrent
// state. Tokens are processed sequentially; q/k (and per-channel gates in the
// KDA variant) are staged in workgroup memory so every thread can reduce over
// the full key dimension.
// Template constants (substituted by the host-side preprocessor):
//   S_V     - value/state head dimension (== workgroup width; host checks it
//             fits maxComputeInvocationsPerWorkgroup)
//   WG_SIZE - workgroup size, presumably equal to S_V -- TODO confirm
//   KDA     - when defined, gates are per-channel (key-dim) instead of scalar.

@group(0) @binding(0)
var<storage, read_write> src_q: array<f32>;

@group(0) @binding(1)
var<storage, read_write> src_k: array<f32>;

@group(0) @binding(2)
var<storage, read_write> src_v: array<f32>;

@group(0) @binding(3)
var<storage, read_write> src_g: array<f32>;

@group(0) @binding(4)
var<storage, read_write> src_beta: array<f32>;

@group(0) @binding(5)
var<storage, read_write> src_state: array<f32>;

// dst holds the attention output first, then (at params.s_off) the updated
// recurrent state.
@group(0) @binding(6)
var<storage, read_write> dst: array<f32>;

struct Params {
    h: u32,        // number of heads
    n_tokens: u32, // tokens per sequence
    n_seqs: u32,   // number of sequences
    s_off: u32,    // element offset in dst where the updated state is written

    // q/k element strides (dims 1..3)
    sq1: u32,
    sq2: u32,
    sq3: u32,

    // v element strides (dims 1..3)
    sv1: u32,
    sv2: u32,
    sv3: u32,

    // gate/beta element strides (dims 1..3)
    sb1: u32,
    sb2: u32,
    sb3: u32,

    neq1: u32,  // q extent in dim 1, for head broadcasting (head_id % neq1)
    rq3: u32,   // sequence broadcast ratio for q in dim 3 (seq_id / rq3)
    scale: f32, // output scale applied to the attention result
};

@group(0) @binding(7)
var<uniform> params: Params;

// Per-token staging of k, q and (KDA only) the per-channel gates.
var<workgroup> sh_k: array<f32, S_V>;
var<workgroup> sh_q: array<f32, S_V>;
var<workgroup> sh_g: array<f32, S_V>;

@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let head_id = workgroup_id.x;
    let seq_id = workgroup_id.y;
    let col = local_id.x; // this thread's value-dimension channel

    // Broadcast indices for q/k, which may have fewer heads/sequences.
    let iq1 = head_id % params.neq1;
    let iq3 = seq_id / params.rq3;

    let state_size = S_V * S_V;
    let state_base = (seq_id * params.h + head_id) * state_size;

    // Load this thread's state column into registers; it stays resident for
    // the whole token loop.
    var state: array<f32, S_V>;
    for (var i = 0u; i < S_V; i++) {
        state[i] = src_state[state_base + col * S_V + i];
    }

    // Output offset of token 0 for this (seq, head); advanced by S_V*h per token.
    var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;

    for (var t = 0u; t < params.n_tokens; t++) {
        let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1;
        let k_off = q_off; // q and k share layout and broadcast indices
        let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1;
        let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1;

        // Stage this token's q/k vectors cooperatively (one element per thread).
        sh_q[col] = src_q[q_off + col];
        sh_k[col] = src_k[k_off + col];

#ifdef KDA
        // Per-channel log-gates: one gate per key-dim element of this token.
        let g_base = gb_off * S_V;
        sh_g[col] = exp(src_g[g_base + col]);
#endif

        // All staged values must be visible before the reductions below.
        workgroupBarrier();

        let v_val = src_v[v_off + col];
        let beta_val = src_beta[gb_off];

        var kv_col = 0.0;    // k . (gated state column)
        var delta_col = 0.0; // delta-rule update for this channel
        var attn_col = 0.0;  // q . (updated state column)

#ifdef KDA
        for (var i = 0u; i < S_V; i++) {
            kv_col += (sh_g[i] * state[i]) * sh_k[i];
        }

        // Delta rule: correction toward v, scaled by beta.
        delta_col = (v_val - kv_col) * beta_val;

        for (var i = 0u; i < S_V; i++) {
            // State recurrence with per-channel decay, then readout with q.
            state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col;
            attn_col += state[i] * sh_q[i];
        }
#else
        // Scalar gate per (token, head).
        let g_val = exp(src_g[gb_off]);

        for (var i = 0u; i < S_V; i++) {
            kv_col += state[i] * sh_k[i];
        }

        delta_col = (v_val - g_val * kv_col) * beta_val;

        for (var i = 0u; i < S_V; i++) {
            state[i] = g_val * state[i] + sh_k[i] * delta_col;
            attn_col += state[i] * sh_q[i];
        }
#endif

        dst[attn_off + col] = attn_col * params.scale;
        attn_off += S_V * params.h;

        // Prevent the next iteration from overwriting sh_q/sh_k/sh_g while
        // any thread is still reducing over them.
        workgroupBarrier();
    }

    // Persist the final recurrent state after the attention outputs.
    for (var i = 0u; i < S_V; i++) {
        dst[params.s_off + state_base + col * S_V + i] = state[i];
    }
}
|
||||
|
|
@ -640,6 +640,35 @@ var<uniform> params: Params;
|
|||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
#ifdef FLOAT_PARALLEL
|
||||
let blocks_per_row = params.ne0 / BLOCK_SIZE;
|
||||
let row_count = params.n_rows * params.ne2 * params.ne3;
|
||||
|
||||
if (gid.x >= blocks_per_row * row_count) {
|
||||
return;
|
||||
}
|
||||
|
||||
let block_idx = gid.x % blocks_per_row;
|
||||
var row_idx = gid.x / blocks_per_row;
|
||||
let i_dst3 = row_idx / (params.ne2 * params.n_rows);
|
||||
|
||||
row_idx = row_idx % (params.ne2 * params.n_rows);
|
||||
let i_dst2 = row_idx / params.n_rows;
|
||||
let i_dst1 = row_idx % params.n_rows;
|
||||
|
||||
let i_idx2 = i_dst3 % params.idx2;
|
||||
let i_idx1 = i_dst2 % params.idx1;
|
||||
let i_idx0 = i_dst1;
|
||||
|
||||
let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
|
||||
|
||||
let idx_val = u32(idx[i_idx]);
|
||||
|
||||
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
|
||||
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
|
||||
|
||||
copy_elements(i_src_row, i_dst_row, block_idx);
|
||||
#else
|
||||
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
|
||||
copy_elements(i_src_row, i_dst_row, i);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -81,11 +81,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
|
|||
}
|
||||
sum = scratch[0];
|
||||
|
||||
#ifdef OP_RMS_NORM
|
||||
#ifdef RMS_NORM
|
||||
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
|
||||
#elif OP_L2_NORM
|
||||
#elif defined(L2_NORM)
|
||||
let scale = 1.0/max(sqrt(sum), params.eps);
|
||||
#endif
|
||||
|
||||
col = lid.x;
|
||||
for (var j: u32 = 0; j < elems; j++) {
|
||||
if (col >= params.ne0) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
// GGML_OP_SET: write src1 into a strided view of src0/dst.
// Two template modes (host-side preprocessor):
//   INPLACE  - dst aliases src0; only the view region is written, one thread
//              per src1 element.
//   (default)- dst is a fresh tensor; one thread per dst element, copying
//              either from src1 (inside the view) or from src0 (outside).
// Binding indices shift by one depending on whether the separate src0 buffer
// exists, hence the SRC1_BINDING/DST_BINDING/PARAMS_BINDING arithmetic.
#ifdef TYPE_I32
#define TYPE i32
#else
#define TYPE f32
#endif

#ifndef INPLACE
@group(0) @binding(0)
var<storage, read_write> src0: array<TYPE>;
#define SRC1_BINDING 1
#else
#define SRC1_BINDING 0
#endif

#define DST_BINDING SRC1_BINDING + 1
#define PARAMS_BINDING SRC1_BINDING + 2

@group(0) @binding(SRC1_BINDING)
var<storage, read_write> src1: array<TYPE>;

@group(0) @binding(DST_BINDING)
var<storage, read_write> dst: array<TYPE>;

struct Params {
    ne: u32,          // total thread count: src1 elements (INPLACE) or dst elements
    offset_src0: u32, // element offset of src0 data
    offset_src1: u32, // element offset of src1 data
    offset_view: u32, // element offset of the view inside dst

    // src1 element strides per dimension
    stride_src10: u32,
    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    // view (dst) element strides per dimension
    stride_dst10: u32,
    stride_dst11: u32,
    stride_dst12: u32,
    stride_dst13: u32,

    // src1 logical extents
    src1_ne0: u32,
    src1_ne1: u32,
    src1_ne2: u32,
    src1_ne3: u32,
};

@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;

// Linear src1 element index -> (i0,i1,i2,i3) logical coordinates.
fn decode_src1_coords(idx: u32) -> vec4<u32> {
    var i = idx;
    let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0;
    let i3 = i / plane;
    i = i % plane;
    let row = params.src1_ne1 * params.src1_ne0;
    let i2 = i / row;
    i = i % row;
    let i1 = i / params.src1_ne0;
    let i0 = i % params.src1_ne0;
    return vec4<u32>(i0, i1, i2, i3);
}

// View-relative offset -> candidate coordinates, by dividing through the view
// strides. Assumes stride_dst13 >= stride_dst12 >= stride_dst11 and
// stride_dst10 == 1 -- TODO confirm against the host-side params setup.
// Offsets that fall in stride "gaps" decode to coords that fail the
// in_set_view round-trip check below.
fn decode_view_coords(rel: u32) -> vec4<u32> {
    let i3 = rel / params.stride_dst13;
    let rem3 = rel % params.stride_dst13;
    let i2 = rem3 / params.stride_dst12;
    let rem2 = rem3 % params.stride_dst12;
    let i1 = rem2 / params.stride_dst11;
    let i0 = rem2 % params.stride_dst11;
    return vec4<u32>(i0, i1, i2, i3);
}

fn view_rel_from_coords(coords: vec4<u32>) -> u32 {
    return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 +
           coords.z * params.stride_dst12 + coords.w * params.stride_dst13;
}

fn src1_idx_from_coords(coords: vec4<u32>) -> u32 {
    return coords.x * params.stride_src10 + coords.y * params.stride_src11 +
           coords.z * params.stride_src12 + coords.w * params.stride_src13;
}

// Round-trip check: rel is a genuine view element iff re-encoding its decoded
// coordinates reproduces it (filters out padding between strided rows).
fn in_set_view(rel: u32, coords: vec4<u32>) -> bool {
    return view_rel_from_coords(coords) == rel;
}

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }

#ifdef INPLACE
    // One thread per src1 element: scatter it into the view.
    let coords = decode_src1_coords(gid.x);

    let src1_idx = params.offset_src1 + src1_idx_from_coords(coords);
    let dst_idx = params.offset_view + view_rel_from_coords(coords);

    dst[dst_idx] = src1[src1_idx];
#else
    // One thread per dst element. Threads before the view start get the
    // sentinel params.ne (>= any valid rel), which fails the range check and
    // routes them to the src0 copy path.
    let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view);
    let coords = decode_view_coords(rel);

    if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) {
        dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)];
    } else {
        dst[gid.x] = src0[params.offset_src0 + gid.x];
    }
#endif
}
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
// GGML_OP_SOLVE_TRI: solve A * X = B by forward substitution, where src0 (A)
// is an N x N triangular matrix and src1 (B) has k right-hand-side columns.
// The substitution loop only reads columns t <= r of row r, so A is treated
// as lower-triangular with a non-zero diagonal (no division guard).
// One workgroup handles K_TILE columns of one (i2, i3) batch; each active
// thread solves one full column, keeping its solution vector X in registers.
// Rows are streamed through shared memory in tiles of BATCH_N to bound
// workgroup-memory use.
// Template constants: WG_SIZE (workgroup width), K_TILE (columns per
// workgroup, <= WG_SIZE), N (matrix dimension), BATCH_N (row-tile height).
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;

@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;

@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // A element strides (dim0 = column, dim1 = row, dims 2/3 = batch)
    stride_src00: u32,
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    // B element strides
    stride_src10: u32,
    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    // X (dst) element strides
    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    k: u32,   // number of right-hand-side columns
    ne2: u32, // batch extent in dim 2
    ne3: u32, // batch extent in dim 3
};

@group(0) @binding(3)
var<uniform> params: Params;

// shA: BATCH_N full rows of A; shB: the matching BATCH_N x K_TILE tile of B.
var<workgroup> shA: array<f32, BATCH_N * N>;
var<workgroup> shB: array<f32, BATCH_N * K_TILE>;

fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
    return params.offset_src0 +
           col * params.stride_src00 +
           row * params.stride_src01 +
           i2 * params.stride_src02 +
           i3 * params.stride_src03;
}

fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
    return params.offset_src1 +
           col * params.stride_src10 +
           row * params.stride_src11 +
           i2 * params.stride_src12 +
           i3 * params.stride_src13;
}

fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
    return params.offset_dst +
           col * params.stride_dst0 +
           row * params.stride_dst1 +
           i2 * params.stride_dst2 +
           i3 * params.stride_dst3;
}

@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>
) {
    let batch = workgroup_id.y;
    let col = workgroup_id.x * WG_SIZE + local_id.x;
    let i3 = batch / params.ne2;
    let i2 = batch % params.ne2;
    // Inactive lanes still participate in the cooperative loads/barriers.
    let active_lane = local_id.x < K_TILE;
    let active_col = active_lane && col < params.k;

    // This thread's solution column, accumulated across row tiles.
    var X: array<f32, N>;

    for (var row_base = 0u; row_base < N; row_base += BATCH_N) {
        let cur_n = min(BATCH_N, N - row_base);

        // Cooperative load: cur_n rows of A (full width N).
        for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) {
            let tile_row = i / N;
            let tile_col = i % N;
            shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)];
        }

        // Cooperative load: matching B tile; zero-pad columns past k so the
        // shared tile is fully initialized.
        for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) {
            let tile_row = i / K_TILE;
            let tile_col = i % K_TILE;
            let global_col = workgroup_id.x * WG_SIZE + tile_col;
            let sh_idx = tile_row * K_TILE + tile_col;

            if (global_col < params.k) {
                shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)];
            } else {
                shB[sh_idx] = 0.0;
            }
        }

        workgroupBarrier();

        if (active_col) {
            // Forward substitution over the rows of this tile. Earlier X
            // entries (t < r) were computed in previous tiles/iterations.
            for (var row_offset = 0u; row_offset < cur_n; row_offset++) {
                let r = row_base + row_offset;
                var b = shB[row_offset * K_TILE + local_id.x];
                let a_row = row_offset * N;

                for (var t = 0u; t < r; t++) {
                    b -= shA[a_row + t] * X[t];
                }

                let x = b / shA[a_row + r]; // divide by the diagonal element
                X[r] = x;
                dst[dst_idx(r, col, i2, i3)] = x;
            }
        }

        // Keep the next tile's loads from clobbering shA/shB while lanes are
        // still substituting.
        workgroupBarrier();
    }
}
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
// GGML_OP_SSM_CONV: short causal convolution used by state-space models.
// Each invocation produces one dst element: the dot product of a length-nc
// window of src0 (the shifted conv state for one row/channel) with that
// row's filter taps in src1.
// Template constants: BLOCK_SIZE / TOKENS_PER_WG (workgroup shape) and
// VECTORIZED (unrolled nc == 4 fast path chosen host-side).
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>;

@group(0) @binding(1)
var<storage, read_write> src1: array<f32>;

@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    stride_src01: u32,
    stride_src02: u32,
    stride_src11: u32,

    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,

    nc: u32,          // convolution kernel width (taps)
    nr: u32,          // rows (channels)
    n_t: u32,         // tokens
    n_s: u32,         // sequences
    token_tiles: u32, // token tiles per sequence along dispatch dim y
};

@group(0) @binding(3)
var<uniform> params: Params;

@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // x indexes the row; y packs (sequence, token tile, token-within-tile).
    let row = gid.x;
    let wg_row = gid.y / TOKENS_PER_WG;
    let tok_in_wg = gid.y % TOKENS_PER_WG;
    let seq = wg_row / params.token_tiles;
    let tok_tile = wg_row % params.token_tiles;
    let token = tok_tile * TOKENS_PER_WG + tok_in_wg;

    // Drop out-of-range threads from the padded dispatch.
    if (row >= params.nr || token >= params.n_t || seq >= params.n_s) {
        return;
    }

    // Window start in src0: token offset is contiguous (stride 1) along dim 0.
    let win = params.offset_src0 + seq * params.stride_src02 + row * params.stride_src01 + token;
    // Filter taps for this row.
    let taps = params.offset_src1 + row * params.stride_src11;

    var acc = 0.0;

#ifdef VECTORIZED
    // Host selects this path when nc == 4: fully unrolled dot product.
    acc =
        src0[win + 0u] * src1[taps + 0u] +
        src0[win + 1u] * src1[taps + 1u] +
        src0[win + 2u] * src1[taps + 2u] +
        src0[win + 3u] * src1[taps + 3u];
#else
    for (var tap = 0u; tap < params.nc; tap++) {
        acc += src0[win + tap] * src1[taps + tap];
    }
#endif

    let out = params.offset_dst + seq * params.stride_dst2 + token * params.stride_dst1 + row * params.stride_dst0;
    dst[out] = acc;
}
|
||||
Loading…
Reference in New Issue