FlashAttention and matrix multiplication moved to the new shader-library format

Reese Levine 2026-02-11 12:32:43 -08:00
parent a4e9b45306
commit ae6baf4714
4 changed files with 893 additions and 619 deletions

File diff suppressed because it is too large

View File

@@ -362,12 +362,9 @@ struct webgpu_context_struct {
std::unordered_map<ggml_webgpu_mul_mat_pipeline_key,
webgpu_pipeline,
ggml_webgpu_mul_mat_pipeline_key_hash>
mul_mat_pipelines; // src0_type, src1_type, vectorized
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
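The per-operation pipeline caches above are plain std::unordered_map / std::map members keyed by small structs with custom hash functors; the flash-attention cache disappears from the context here because the shader library now owns it. For readers unfamiliar with the pattern, a minimal self-contained sketch follows, using hypothetical demo_* names rather than the real ggml_webgpu_* key types (the real mul_mat key also carries per-path flags, as the mul_mat hunk further down shows).

#include <cstddef>
#include <functional>
#include <unordered_map>

// Hypothetical simplified key; the real ggml_webgpu_mul_mat_pipeline_key also carries
// per-path flags (is_vec, use_subgroup_matrix, register_tile).
struct demo_pipeline_key {
    int src0_type;
    int src1_type;
    int vectorized;
    bool operator==(const demo_pipeline_key & o) const {
        return src0_type == o.src0_type && src1_type == o.src1_type && vectorized == o.vectorized;
    }
};

// Hash functor supplied as the third template argument of std::unordered_map.
struct demo_pipeline_key_hash {
    size_t operator()(const demo_pipeline_key & k) const {
        size_t h = std::hash<int>()(k.src0_type);
        h = h * 31u + std::hash<int>()(k.src1_type);
        h = h * 31u + std::hash<int>()(k.vectorized);
        return h;
    }
};

struct demo_pipeline {};  // stand-in for webgpu_pipeline

using demo_pipeline_cache = std::unordered_map<demo_pipeline_key, demo_pipeline, demo_pipeline_key_hash>;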
@@ -891,9 +888,7 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
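In the new format every op builds a ggml_webgpu_shader_lib_context (the tensors it reads and writes plus the device limits it needs) and asks the shader library for a ready pipeline, rather than preprocessing and compiling WGSL at the call site. A pared-down sketch of that shape with hypothetical demo_* types; the real context also carries subgroup-matrix capabilities, as the mul_mat and flash-attention hunks below show.

#include <cstdint>

struct demo_tensor;        // stand-in for ggml_tensor
struct demo_pipeline {};   // stand-in for webgpu_pipeline

// Hypothetical pared-down mirror of ggml_webgpu_shader_lib_context.
struct demo_shader_lib_context {
    const demo_tensor * src0 = nullptr;
    const demo_tensor * src1 = nullptr;
    const demo_tensor * dst  = nullptr;
    uint32_t max_wg_size        = 0;  // device limit: maxComputeInvocationsPerWorkgroup
    uint32_t wg_mem_limit_bytes = 0;  // device limit: maxComputeWorkgroupStorageSize
};

struct demo_shader_lib {
    // Getters such as get_pad_pipeline()/get_set_rows_pipeline() specialize, compile and
    // cache a pipeline for the given context, returning the cached one on later calls.
    demo_pipeline get_pad_pipeline(const demo_shader_lib_context & /*ctx*/) { return demo_pipeline{}; }
};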
@@ -957,7 +952,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .src1 = idx, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
.src0 = src,
.src1 = idx,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
@@ -1032,13 +1030,13 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
ggml_tensor * idx,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.dst = dst,
.max_wg_size = WEBGPU_MAX_WG_SIZE,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@@ -1108,88 +1106,29 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
break;
}
int vectorized = 0;
if (use_fast) {
vectorized = (src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0);
if (is_vec) {
// We don't support vectorized mul_mat_vec for quantized types
vectorized = vectorized && (src0->type < 2);
}
}
// Create pipeline key
bool supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
ggml_webgpu_mul_mat_pipeline_key key = { .src0_type = src0->type,
.src1_type = src1->type,
.vectorized = use_fast ? vectorized : 0,
.is_vec = (use_fast && is_vec) ? 1 : 0,
.use_subgroup_matrix =
(use_fast && !is_vec && supports_subgroup_matrix) ? 1 : 0,
.register_tile =
(use_fast && !is_vec && !supports_subgroup_matrix) ? 1 : 0 };
// Build shader context
ggml_webgpu_mul_mat_shader_lib_context shader_lib_ctx = { .key = key,
.max_subgroup_size =
ctx->global_ctx->capabilities.max_subgroup_size,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k };
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
};
// Get or create pipeline
webgpu_pipeline pipeline;
const char * shader_src = nullptr;
auto it = ctx->mul_mat_pipelines.find(key);
if (it != ctx->mul_mat_pipelines.end()) {
pipeline = it->second;
if (use_fast && is_vec) {
pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
} else if (use_fast) {
pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
} else {
// Select appropriate shader source based on key
if (!use_fast) {
// Use precompiled quantized shaders (mul_mat.tmpl.wgsl)
// These are the fallback for quantized types not supported by fast
// paths
shader_src = wgsl_mul_mat;
} else {
// Use JIT-compiled shader
if (is_vec) {
shader_src = wgsl_mul_mat_vec;
} else if (key.use_subgroup_matrix) {
shader_src = wgsl_mul_mat_subgroup_matrix;
} else {
shader_src = wgsl_mul_mat_reg_tile;
}
}
if (shader_src) {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_mul_mat_shader(ctx->p, shader_src, shader_lib_ctx);
std::vector<wgpu::ConstantEntry> constants;
if (shader_lib_ctx.key.is_vec) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
constants.push_back({ nullptr, "WORKGROUP_SIZE", static_cast<double>(decisions->wg_size) });
constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
constants.push_back({ nullptr, "OUTPUTS_PER_WG", static_cast<double>(decisions->outputs_per_wg) });
} else if (shader_lib_ctx.key.register_tile) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
constants.push_back({ nullptr, "WORKGROUP_SIZE_M", static_cast<double>(decisions->wg_size_m) });
constants.push_back({ nullptr, "WORKGROUP_SIZE_N", static_cast<double>(decisions->wg_size_n) });
constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
}
// printf("DEBUG: Creating pipeline with variant='%s', "
// "constants.size()=%zu\n",
// processed.variant.c_str(), constants.size());
pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
processed.variant.c_str(), constants);
pipeline.context = processed.decisions;
ctx->mul_mat_pipelines.emplace(key, pipeline);
}
pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
}
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
// Build params
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
@@ -1230,13 +1169,17 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t wg_x = 1;
uint32_t wg_y = 1;
if (decisions->is_vec) {
if (use_fast && is_vec) {
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
} else if (use_fast) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
// Fast-path tiled/subgroup calculations
uint32_t wg_m, wg_n;
if (decisions->use_subgroup_matrix) {
@@ -1253,11 +1196,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
}
wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
} else {
// Non-fast-path quantized shaders (Q2_K, Q4_K, etc.)
// Use the value from decisions instead of hardcoded constant
wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->mul_mat_wg_size);
wg_y = 1;
} else { // legacy
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_size = decisions->wg_size;
wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
wg_y = 1;
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
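The sizing logic above works because the shader library stashes a per-pipeline decisions struct in pipeline.context as an opaque shared_ptr, and each dispatch path casts it back to the concrete type it expects (mul_mat_vec, fast, or generic). A small self-contained demo of that mechanism with hypothetical demo_* types; the numbers and the workgroup math only mirror the vector path above in spirit.

#include <cstdint>
#include <cstdio>
#include <memory>

struct demo_pipeline {
    std::shared_ptr<void> context;  // opaque per-pipeline decisions, set by the shader library
};

struct demo_vec_decisions {
    uint32_t wg_size;
    uint32_t outputs_per_wg;
};

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    // The library records its tuning choices when it builds the pipeline...
    demo_pipeline pipeline;
    pipeline.context = std::make_shared<demo_vec_decisions>(demo_vec_decisions{ 256, 4 });

    // ...and the dispatcher reads them back to size the grid, as the vector path above does.
    auto * d = static_cast<demo_vec_decisions *>(pipeline.context.get());
    uint32_t rows = 4096, batches = 8;
    uint32_t wg_x = ceil_div(rows, d->outputs_per_wg) * batches;
    std::printf("dispatch %u workgroups of %u threads\n", wg_x, d->wg_size);
    return 0;
}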
@@ -1347,40 +1290,22 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
ggml_webgpu_flash_attn_pipeline_key key = {
.kv_type = K->type,
.head_dim_qk = (uint32_t) Q->ne[0],
.head_dim_v = (uint32_t) V->ne[0],
.kv_direct = kv_direct,
.has_mask = static_cast<bool>(has_mask),
.has_sinks = static_cast<bool>(has_sinks),
.uses_logit_softcap = logit_softcap != 0.0f,
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = Q,
.src1 = K,
.src2 = V,
.src3 = mask,
.src4 = sinks,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
};
webgpu_pipeline pipeline;
auto it = ctx->flash_attn_pipelines.find(key);
if (it != ctx->flash_attn_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
.key = key,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
};
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->flash_attn_pipelines.emplace(key, pipeline);
}
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
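The specialization inputs this call site used to pack into ggml_webgpu_flash_attn_pipeline_key (KV type, head dimensions, mask/sinks presence, logit softcap) are now implied by the tensors handed over in the generic context, and get_flash_attn_pipeline presumably derives them internally; its implementation sits in the shader-library file whose diff is suppressed above, so the mapping below is an assumption reconstructed from the removed key construction, using a hypothetical demo_ struct.

#include <cstdint>
#include "ggml.h"

// Hypothetical mirror of the removed flash-attention pipeline key
// (kv_direct, computed from K's type and alignment, is omitted here).
struct demo_flash_attn_key {
    enum ggml_type kv_type;
    uint32_t head_dim_qk;
    uint32_t head_dim_v;
    bool     has_mask;
    bool     has_sinks;
    bool     uses_logit_softcap;
};

// Assumed derivation, following the expressions deleted above.
static demo_flash_attn_key demo_make_flash_attn_key(const ggml_tensor * Q, const ggml_tensor * K,
                                                    const ggml_tensor * V, const ggml_tensor * mask,
                                                    const ggml_tensor * sinks, float logit_softcap) {
    demo_flash_attn_key key;
    key.kv_type            = K->type;
    key.head_dim_qk        = (uint32_t) Q->ne[0];
    key.head_dim_v         = (uint32_t) V->ne[0];
    key.has_mask           = mask  != nullptr;
    key.has_sinks          = sinks != nullptr;
    key.uses_logit_softcap = logit_softcap != 0.0f;
    return key;
}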
@@ -1402,7 +1327,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
.inplace = inplace,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -1483,7 +1408,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
.overlap = flags.overlap,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -1860,19 +1785,18 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
}
static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_top_k = dst->op == GGML_OP_TOP_K;
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
};
webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
auto * argsort_decisions =
static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());
webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
@@ -2034,14 +1958,14 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
uint32_t wg_x = ggml_nrows(dst);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
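cumsum dispatches one workgroup per logical row, with ggml_nrows(dst) collapsing the higher dimensions into a row count. A one-line stand-in, under the assumption that ggml treats ne[0] as the row length and ne[1]*ne[2]*ne[3] as the number of rows.

#include <cstdint>

// Hypothetical stand-in for ggml_nrows(): rows = product of every dimension above ne[0].
static int64_t demo_nrows(const int64_t ne[4]) {
    return ne[1] * ne[2] * ne[3];
}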

View File

@@ -34,7 +34,6 @@ struct MulMatParams {
broadcast3: u32
};
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
@@ -48,14 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 {
return thread_id % WORKGROUP_SIZE_M;
}
override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;
override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
@@ -142,4 +136,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
}
}
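The tile dimensions in this shader switch from pipeline-overridable constants to plain const declarations because the shader library now bakes its chosen values directly into the generated WGSL, so nothing has to be supplied at pipeline-creation time. For reference, a minimal sketch of the older override mechanism with Dawn's webgpu_cpp API, corresponding to the wgpu::ConstantEntry code removed from the mul_mat path above; the tile values are placeholders, not the tuned ones.

#include <webgpu/webgpu_cpp.h>
#include <vector>

// Older scheme: the WGSL 'override' constants were filled in when the pipeline was created.
wgpu::ComputePipeline demo_create_with_overrides(wgpu::Device device, wgpu::ShaderModule module) {
    std::vector<wgpu::ConstantEntry> constants = {
        { nullptr, "WORKGROUP_SIZE_M", 8.0 },   // placeholder tile sizes
        { nullptr, "WORKGROUP_SIZE_N", 8.0 },
        { nullptr, "TILE_K", 32.0 },
    };
    wgpu::ComputePipelineDescriptor desc;
    desc.compute.module        = module;
    desc.compute.entryPoint    = "main";
    desc.compute.constants     = constants.data();
    desc.compute.constantCount = constants.size();
    return device.CreateComputePipeline(&desc);
}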

View File

@@ -98,16 +98,13 @@ struct MulMatParams {
@group(0) @binding(3) var<uniform> params: MulMatParams;
override WORKGROUP_SIZE: u32;
override TILE_K: u32;
override OUTPUTS_PER_WG: u32;
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
// Shared memory for collaborative loading and reduction
var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>; // Cache vector tile
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
var<workgroup> partial_sums: array<f32, WG_SIZE>; // For reduction
@compute @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WG_SIZE)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>,
@@ -150,7 +147,7 @@ fn main(
let tile_size = min(TILE_K, params.k - k_tile);
// Cooperatively load vector tile into shared memory (all threads)
for (var i = thread_id * VEC_SIZE; i < tile_size; i += WORKGROUP_SIZE * VEC_SIZE) {
for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
}
@@ -168,7 +165,7 @@ fn main(
workgroupBarrier();
let group_base = thread_group * THREADS_PER_OUTPUT;
let thread_base = group_base + thread_in_group;
var offset = THREADS_PER_OUTPUT / 2;
var offset: u32 = THREADS_PER_OUTPUT / 2;
while (offset > 0) {
if (thread_in_group < offset) {
partial_sums[thread_base] += partial_sums[thread_base + offset];