refactor argmax, set_rows

Reese Levine 2026-02-10 21:13:37 -08:00
parent 75e66cb49d
commit 8a13bbb11b
2 changed files with 118 additions and 116 deletions
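The refactor moves the argmax and set_rows pipeline caches behind ggml_webgpu_shader_lib: each variant is keyed by a small struct, hashed with a boost-style hash_combine, and compiled only on the first request for that key. A minimal, self-contained sketch of that lookup-or-compile pattern follows; the names here (variant_key, compile_variant, get_or_compile) are illustrative stand-ins, not the backend's actual API.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Boost-style hash combine, matching the helper used for the pipeline keys.
template <typename T> inline void hash_combine(std::size_t & seed, const T & value) {
    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// One entry per compiled shader variant (dst type, vec4 path, i64 indices).
struct variant_key {
    int  dst_type;
    bool vec4;
    bool i64_idx;
    bool operator==(const variant_key & o) const {
        return dst_type == o.dst_type && vec4 == o.vec4 && i64_idx == o.i64_idx;
    }
};

struct variant_key_hash {
    std::size_t operator()(const variant_key & k) const {
        std::size_t seed = 0;
        hash_combine(seed, k.dst_type);
        hash_combine(seed, k.vec4);
        hash_combine(seed, k.i64_idx);
        return seed;
    }
};

// Stand-in for shader preprocessing + pipeline creation.
std::string compile_variant(const variant_key & k) {
    std::string variant = "set_rows";
    if (k.vec4)    { variant += "_vec4"; }
    if (k.i64_idx) { variant += "_i64idx"; }
    return variant;
}

// Compile on the first request for a key, then serve the cached result.
const std::string & get_or_compile(std::unordered_map<variant_key, std::string, variant_key_hash> & cache,
                                   const variant_key & key) {
    auto it = cache.find(key);
    if (it == cache.end()) {
        it = cache.emplace(key, compile_variant(key)).first;
    }
    return it->second;
}

int main() {
    std::unordered_map<variant_key, std::string, variant_key_hash> cache;
    std::cout << get_or_compile(cache, { 0, true, false }) << "\n";  // compiles "set_rows_vec4"
    std::cout << get_or_compile(cache, { 0, true, false }) << "\n";  // cache hit, no recompile
}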

View File

@@ -1,15 +1,16 @@
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
#define GGML_WEBGPU_SHADER_LIB_HPP
#include "ggml-wgsl-shaders.hpp"
#include "ggml.h"
#include "pre_wgsl.hpp"
#include "ggml-wgsl-shaders.hpp"
#include <webgpu/webgpu_cpp.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#define GGML_WEBGPU_F16_SIZE_BYTES 2
@@ -50,6 +51,11 @@
#define WEBGPU_MAX_WG_SIZE 288
#define WEBGPU_MUL_MAT_WG_SIZE 256
// Same hash combine function as in boost
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
struct ggml_webgpu_shader_lib_context {
ggml_tensor * src0;
ggml_tensor * src1;
@@ -64,16 +70,48 @@ struct webgpu_pipeline {
std::shared_ptr<void> context = nullptr;
};
/** Set Rows **/
struct ggml_webgpu_set_rows_pipeline_key {
int dst_type;
int vec4;
int i64_idx;
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
}
};
struct ggml_webgpu_set_rows_pipeline_key_hash {
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.vec4);
ggml_webgpu_hash_combine(seed, key.i64_idx);
return seed;
}
};
struct ggml_webgpu_set_rows_shader_decisions {
bool vec4;
bool i64_idx;
uint32_t wg_size;
};
class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet
std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
webgpu_pipeline get_sum_rows_pipeline(ggml_webgpu_shader_lib_context & context) {
webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
auto it = sum_rows_pipelines.find(1);
if (it != sum_rows_pipelines.end()) {
return it->second;
@@ -86,8 +124,74 @@ class ggml_webgpu_shader_lib {
return sum_rows_pipelines[1];
}
private:
webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
bool vec4 = context.src0->ne[0] % 4 == 0;
auto it = argmax_pipelines.find(vec4);
if (it != argmax_pipelines.end()) {
return it->second;
}
std::string variant = "argmax";
std::vector<std::string> defines;
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
if (vec4) {
defines.push_back("VEC4");
variant += "_vec4";
}
auto processed = preprocessor.preprocess(wgsl_argmax, defines);
argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant);
return argmax_pipelines.at(vec4);
}
webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
.vec4 = context.src0->ne[0] % 4 == 0,
.i64_idx = context.src1->type == GGML_TYPE_I64 };
auto it = set_rows_pipelines.find(key);
if (it != set_rows_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "set_rows";
switch (context.dst->type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
variant += "_dstf32";
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
variant += "_dstf16";
break;
default:
GGML_ABORT("Unsupported dst type for set_rows shader");
}
if (key.vec4) {
defines.push_back("VEC4");
variant += "_vec4";
}
if (key.i64_idx) {
defines.push_back("I64_IDX");
variant += "_i64idx";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_set_rows, defines);
auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
decisions->vec4 = key.vec4;
decisions->i64_idx = key.i64_idx;
decisions->wg_size = context.max_wg_size;
set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
set_rows_pipelines[key].context = decisions;
return set_rows_pipelines[key];
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,
std::string label) {
@@ -126,11 +230,6 @@ struct ggml_webgpu_processed_shader {
std::shared_ptr<void> decisions;
};
// Same hash combine function as in boost
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
/** FlashAttention */
struct ggml_webgpu_flash_attn_pipeline_key {
@@ -436,73 +535,6 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
return result;
}
/** Set Rows **/
struct ggml_webgpu_set_rows_pipeline_key {
int dst_type;
int vec4;
int i64_idx;
bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
}
};
struct ggml_webgpu_set_rows_pipeline_key_hash {
size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.vec4);
ggml_webgpu_hash_combine(seed, key.i64_idx);
return seed;
}
};
struct ggml_webgpu_set_rows_shader_lib_context {
ggml_webgpu_set_rows_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_set_rows_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "set_rows";
switch (context.key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
variant += "_dstf32";
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
variant += "_dstf16";
break;
default:
GGML_ABORT("Unsupported dst type for set_rows shader");
}
if (context.key.vec4) {
defines.push_back("VEC4");
variant += "_vec";
}
if (context.key.i64_idx) {
defines.push_back("I64_IDX");
variant += "_i64idx";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
struct ggml_webgpu_unary_pipeline_key {
int type;
int op;

View File

@@ -369,13 +369,10 @@ struct webgpu_context_struct {
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4
std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order (asc/desc)
std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order (asc/desc)
std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::unordered_map<ggml_webgpu_get_rows_pipeline_key,
webgpu_pipeline,
ggml_webgpu_get_rows_pipeline_key_hash>
@@ -989,31 +986,16 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
return std::nullopt;
}
ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type,
.vec4 = src->ne[0] % 4 == 0,
.i64_idx = idx->type == GGML_TYPE_I64 };
ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = {
.key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .src1 = idx, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline;
auto it = ctx->set_rows_pipelines.find(key);
if (it != ctx->set_rows_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->set_rows_pipelines.emplace(key, pipeline);
}
webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
if (key.i64_idx) {
if (decisions->i64_idx) {
error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
error_bufs->host_buf.Unmap();
@@ -1051,13 +1033,13 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
if (key.i64_idx) {
if (decisions->i64_idx) {
entries.push_back(
{ .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
}
uint32_t threads;
if (key.vec4) {
if (decisions->vec4) {
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
} else {
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
@@ -1812,7 +1794,6 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
};
webgpu_pipeline pipeline;
// TODO: remove guard once pipeline caches are per-thread
auto it = ctx->scale_pipelines.find(key);
if (it != ctx->scale_pipelines.end()) {
pipeline = it->second;
@@ -1954,23 +1935,12 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
ggml_webgpu_generic_shader_lib_context shader_lib_ctx = {
.vec4 = src->ne[0] % 4 == 0,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline;
auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4);
if (it != ctx->argmax_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax");
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline);
}
uint32_t wg_x = ggml_nelements(dst);
webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);
uint32_t wg_x = ggml_nelements(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
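After this change the per-variant decisions (vec4, i64 indices, workgroup size) ride along in pipeline.context as a type-erased std::shared_ptr<void>, and the op recovers them with a static_cast instead of recomputing the key at the call site. A small self-contained sketch of that hand-off, using stand-in types rather than the real webgpu_pipeline, is:

#include <cstdint>
#include <iostream>
#include <memory>

// Stand-ins for the backend types; only the context hand-off is illustrated.
struct set_rows_decisions {
    bool     vec4;
    bool     i64_idx;
    uint32_t wg_size;
};

struct pipeline {
    std::shared_ptr<void> context;  // type-erased per-variant decisions
};

int main() {
    // Shader-lib side: record the choices made while building the variant.
    auto decisions = std::make_shared<set_rows_decisions>();
    decisions->vec4    = true;
    decisions->i64_idx = false;
    decisions->wg_size = 256;

    pipeline p;
    p.context = decisions;  // shared_ptr<void> keeps the original deleter, so ownership stays safe

    // Op side: recover the decisions and size the dispatch from them.
    auto *   d       = static_cast<set_rows_decisions *>(p.context.get());
    uint32_t total   = 4096;                                // pretend element count
    uint32_t threads = d->vec4 ? total / 4 : total;
    uint32_t wg_x    = (threads + d->wg_size - 1) / d->wg_size;
    std::cout << "dispatch " << wg_x << " workgroup(s)\n";  // prints "dispatch 4 workgroup(s)"
}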