refactor argmax, set_rows
This commit is contained in:
parent
75e66cb49d
commit
8a13bbb11b
|
|
@ -1,15 +1,16 @@
|
|||
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
|
||||
#define GGML_WEBGPU_SHADER_LIB_HPP
|
||||
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
#include "ggml.h"
|
||||
#include "pre_wgsl.hpp"
|
||||
#include "ggml-wgsl-shaders.hpp"
|
||||
|
||||
#include <webgpu/webgpu_cpp.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define GGML_WEBGPU_F16_SIZE_BYTES 2
|
||||
|
|
@ -50,6 +51,11 @@
|
|||
#define WEBGPU_MAX_WG_SIZE 288
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
||||
|
||||
// Same hash combine function as in boost.
// Mixes std::hash<T> of `value` into `seed` in place: the value's hash is added
// to a fixed constant and to two shifted copies of the current seed, and the
// sum is XOR-ed back into the seed.
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
    const size_t value_hash = std::hash<T>{}(value);
    const size_t mixed      = value_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed                    = seed ^ mixed;
}
|
||||
|
||||
struct ggml_webgpu_shader_lib_context {
|
||||
ggml_tensor * src0;
|
||||
ggml_tensor * src1;
|
||||
|
|
@ -64,16 +70,48 @@ struct webgpu_pipeline {
|
|||
std::shared_ptr<void> context = nullptr;
|
||||
};
|
||||
|
||||
/** Set Rows **/
|
||||
|
||||
// Cache key identifying one variant of the set_rows pipeline.
// All members default to 0 so that a default-constructed key never carries
// indeterminate values into operator== or the hash functor. The struct is
// still an aggregate, so brace / designated initialization keeps working.
struct ggml_webgpu_set_rows_pipeline_key {
    int dst_type = 0;  // type of the destination tensor
    int vec4     = 0;  // non-zero selects the VEC4 shader variant
    int i64_idx  = 0;  // non-zero selects the I64_IDX shader variant

    // Two keys are equal when all three fields match.
    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
    }
};
|
||||
|
||||
struct ggml_webgpu_set_rows_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vec4);
|
||||
ggml_webgpu_hash_combine(seed, key.i64_idx);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
// Records the choices made when a set_rows pipeline variant was built, so the
// caller can read them back from the pipeline's context.
// Members carry default initializers so that a default-initialized instance
// (e.g. a plain local) never exposes indeterminate values.
struct ggml_webgpu_set_rows_shader_decisions {
    bool     vec4    = false;  // true when the VEC4 variant was selected
    bool     i64_idx = false;  // true when the I64_IDX variant was selected
    uint32_t wg_size = 0;      // value passed to the shader as WG_SIZE
};
|
||||
|
||||
class ggml_webgpu_shader_lib {
|
||||
wgpu::Device device;
|
||||
pre_wgsl::Preprocessor preprocessor;
|
||||
|
||||
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
|
||||
std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
|
||||
|
||||
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
||||
set_rows_pipelines;
|
||||
|
||||
public:
|
||||
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
|
||||
|
||||
webgpu_pipeline get_sum_rows_pipeline(ggml_webgpu_shader_lib_context & context) {
|
||||
webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
auto it = sum_rows_pipelines.find(1);
|
||||
if (it != sum_rows_pipelines.end()) {
|
||||
return it->second;
|
||||
|
|
@ -86,8 +124,74 @@ class ggml_webgpu_shader_lib {
|
|||
return sum_rows_pipelines[1];
|
||||
}
|
||||
|
||||
private:
|
||||
// Returns the argmax pipeline matching `context`, creating it on the first
// request and serving later requests from argmax_pipelines.
// The only variant axis is VEC4, selected when context.src0->ne[0] is a
// multiple of 4; context.max_wg_size is passed to the shader as WG_SIZE.
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const bool use_vec4 = context.src0->ne[0] % 4 == 0;

    const auto cached = argmax_pipelines.find(use_vec4);
    if (cached != argmax_pipelines.end()) {
        return cached->second;
    }

    // Assemble the preprocessor defines and the pipeline label together.
    std::vector<std::string> defines = { std::string("WG_SIZE=") + std::to_string(context.max_wg_size) };
    std::string              label   = "argmax";
    if (use_vec4) {
        defines.push_back("VEC4");
        label += "_vec4";
    }

    auto shader_code           = preprocessor.preprocess(wgsl_argmax, defines);
    argmax_pipelines[use_vec4] = ggml_webgpu_create_pipeline(device, shader_code, label);
    return argmax_pipelines.at(use_vec4);
}
|
||||
|
||||
// Returns the set_rows pipeline matching `context`, creating and caching it in
// set_rows_pipelines on the first request.
// The variant is keyed on the destination type (F32 or F16; anything else
// aborts), on whether context.src0->ne[0] is a multiple of 4 (VEC4), and on
// whether context.src1 is GGML_TYPE_I64 (I64_IDX). The returned pipeline's
// `context` points at a ggml_webgpu_set_rows_shader_decisions describing the
// chosen variant and workgroup size.
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_set_rows_pipeline_key key{};
    key.dst_type = context.dst->type;
    key.vec4     = context.src0->ne[0] % 4 == 0;
    key.i64_idx  = context.src1->type == GGML_TYPE_I64;

    const auto cached = set_rows_pipelines.find(key);
    if (cached != set_rows_pipelines.end()) {
        return cached->second;
    }

    std::vector<std::string> defines;
    std::string              label = "set_rows";

    // Destination type selects the DST_* define; unsupported types are fatal.
    if (context.dst->type == GGML_TYPE_F32) {
        defines.push_back("DST_F32");
        label += "_dstf32";
    } else if (context.dst->type == GGML_TYPE_F16) {
        defines.push_back("DST_F16");
        label += "_dstf16";
    } else {
        GGML_ABORT("Unsupported dst type for set_rows shader");
    }

    if (key.vec4) {
        defines.push_back("VEC4");
        label += "_vec4";
    }
    if (key.i64_idx) {
        defines.push_back("I64_IDX");
        label += "_i64idx";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    auto shader_code = preprocessor.preprocess(wgsl_set_rows, defines);

    // Remember what was chosen so the dispatch code can read it back later.
    auto decisions     = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
    decisions->vec4    = key.vec4;
    decisions->i64_idx = key.i64_idx;
    decisions->wg_size = context.max_wg_size;

    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, shader_code, label);
    pipeline.context         = decisions;
    set_rows_pipelines[key]  = pipeline;
    return pipeline;
}
|
||||
|
||||
private:
|
||||
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
|
||||
std::string shader_code,
|
||||
std::string label) {
|
||||
|
|
@ -126,11 +230,6 @@ struct ggml_webgpu_processed_shader {
|
|||
std::shared_ptr<void> decisions;
|
||||
};
|
||||
|
||||
// Same hash combine function as in boost.
// Mixes std::hash<T> of `value` into `seed` in place: the value's hash is added
// to a fixed constant and to two shifted copies of the current seed, and the
// sum is XOR-ed back into the seed.
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
    const size_t value_hash = std::hash<T>{}(value);
    const size_t mixed      = value_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed                    = seed ^ mixed;
}
|
||||
|
||||
/** FlashAttention */
|
||||
|
||||
struct ggml_webgpu_flash_attn_pipeline_key {
|
||||
|
|
@ -436,73 +535,6 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
|
|||
return result;
|
||||
}
|
||||
|
||||
/** Set Rows **/
|
||||
|
||||
// Cache key identifying one variant of the set_rows pipeline.
// All members default to 0 so that a default-constructed key never carries
// indeterminate values into operator== or the hash functor. The struct is
// still an aggregate, so brace / designated initialization keeps working.
struct ggml_webgpu_set_rows_pipeline_key {
    int dst_type = 0;  // type of the destination tensor
    int vec4     = 0;  // non-zero selects the VEC4 shader variant
    int i64_idx  = 0;  // non-zero selects the I64_IDX shader variant

    // Two keys are equal when all three fields match.
    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
    }
};
|
||||
|
||||
struct ggml_webgpu_set_rows_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vec4);
|
||||
ggml_webgpu_hash_combine(seed, key.i64_idx);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_set_rows_shader_lib_context {
|
||||
ggml_webgpu_set_rows_pipeline_key key;
|
||||
uint32_t max_wg_size;
|
||||
};
|
||||
|
||||
// Preprocesses the set_rows WGSL source for the variant described by `context`.
//
// preprocessor: WGSL preprocessor used to expand `shader_src` with the defines.
// shader_src:   WGSL template source for set_rows.
// context:      variant key (dst type, vec4, i64 index) and workgroup size.
//
// Returns the processed WGSL, a variant label, and a
// ggml_webgpu_generic_shader_decisions carrying the workgroup size.
// Aborts via GGML_ABORT when the destination type is neither F32 nor F16.
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
    pre_wgsl::Preprocessor &                        preprocessor,
    const char *                                    shader_src,
    const ggml_webgpu_set_rows_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string              variant = "set_rows";

    switch (context.key.dst_type) {
        case GGML_TYPE_F32:
            defines.push_back("DST_F32");
            variant += "_dstf32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("DST_F16");
            variant += "_dstf16";
            break;
        default:
            GGML_ABORT("Unsupported dst type for set_rows shader");
    }

    if (context.key.vec4) {
        defines.push_back("VEC4");
        // Label suffix matches the "_vec4" used by the other VEC4 variants.
        variant += "_vec4";
    }
    if (context.key.i64_idx) {
        defines.push_back("I64_IDX");
        variant += "_i64idx";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;

    // The dispatch code reads the workgroup size back from the decisions.
    auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    result.decisions   = decisions;
    return result;
}
|
||||
|
||||
struct ggml_webgpu_unary_pipeline_key {
|
||||
int type;
|
||||
int op;
|
||||
|
|
|
|||
|
|
@ -369,13 +369,10 @@ struct webgpu_context_struct {
|
|||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_pipelines;
|
||||
|
||||
std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
|
||||
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order (asc/desc)
|
||||
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
|
||||
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
|
||||
|
||||
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
||||
set_rows_pipelines;
|
||||
std::unordered_map<ggml_webgpu_get_rows_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_get_rows_pipeline_key_hash>
|
||||
|
|
@ -989,31 +986,16 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
|||
return std::nullopt;
|
||||
}
|
||||
|
||||
ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type,
|
||||
.vec4 = src->ne[0] % 4 == 0,
|
||||
.i64_idx = idx->type == GGML_TYPE_I64 };
|
||||
|
||||
ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = {
|
||||
.key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src, .src1 = idx, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline;
|
||||
auto it = ctx->set_rows_pipelines.find(key);
|
||||
if (it != ctx->set_rows_pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
} else {
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
|
||||
pipeline =
|
||||
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
|
||||
pipeline.context = processed.decisions;
|
||||
ctx->set_rows_pipelines.emplace(key, pipeline);
|
||||
}
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
|
||||
if (key.i64_idx) {
|
||||
if (decisions->i64_idx) {
|
||||
error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
|
||||
if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
|
||||
error_bufs->host_buf.Unmap();
|
||||
|
|
@ -1051,13 +1033,13 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
|||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
if (key.i64_idx) {
|
||||
if (decisions->i64_idx) {
|
||||
entries.push_back(
|
||||
{ .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
|
||||
}
|
||||
|
||||
uint32_t threads;
|
||||
if (key.vec4) {
|
||||
if (decisions->vec4) {
|
||||
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
|
||||
} else {
|
||||
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
|
||||
|
|
@ -1812,7 +1794,6 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
|
|||
};
|
||||
|
||||
webgpu_pipeline pipeline;
|
||||
// TODO: remove guard once pipeline caches are per-thread
|
||||
auto it = ctx->scale_pipelines.find(key);
|
||||
if (it != ctx->scale_pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
|
|
@ -1954,23 +1935,12 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
|
|||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
|
||||
.vec4 = src->ne[0] % 4 == 0,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline;
|
||||
auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
|
||||
if (it != ctx->argmax_pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
} else {
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
|
||||
pipeline =
|
||||
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
|
||||
ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
|
||||
}
|
||||
uint32_t wg_x = ggml_nelements(dst);
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);
|
||||
uint32_t wg_x = ggml_nelements(dst);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue