flash attention and matrix multiplication moved to new format
parent a4e9b45306, commit ae6baf4714
@@ -362,12 +362,9 @@ struct webgpu_context_struct {
    std::unordered_map<ggml_webgpu_mul_mat_pipeline_key,
                       webgpu_pipeline,
                       ggml_webgpu_mul_mat_pipeline_key_hash>
        mul_mat_pipelines; // src0_type, src1_type, vectorized
    std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
        mul_mat_vec_pipelines; // src0_type, src1_type, vectorized

    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
        flash_attn_pipelines;

    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type

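Note: the mul_mat and flash_attn caches in this hunk pair a small aggregate key with a matching hash functor, which is what std::unordered_map requires. A minimal sketch of that pattern, using the field names that appear in this commit's key initializer; the real definitions in the tree use ggml enums and may hash differently:

#include <cstddef>
#include <functional>
#include <initializer_list>

// Sketch only: field set taken from the key initializer in this diff; types are simplified to int.
struct ggml_webgpu_mul_mat_pipeline_key {
    int src0_type;
    int src1_type;
    int vectorized;
    int is_vec;
    int use_subgroup_matrix;
    int register_tile;

    bool operator==(const ggml_webgpu_mul_mat_pipeline_key & o) const {
        return src0_type == o.src0_type && src1_type == o.src1_type && vectorized == o.vectorized &&
               is_vec == o.is_vec && use_subgroup_matrix == o.use_subgroup_matrix && register_tile == o.register_tile;
    }
};

// Hash functor combining the fields; the constant is the usual boost-style mixer.
struct ggml_webgpu_mul_mat_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & k) const {
        size_t h = std::hash<int>{}(k.src0_type);
        for (int v : { k.src1_type, k.vectorized, k.is_vec, k.use_subgroup_matrix, k.register_tile }) {
            h ^= std::hash<int>{}(v) + 0x9e3779b9 + (h << 6) + (h >> 2);
        }
        return h;
    }
};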
@@ -891,9 +888,7 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g

static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
    };

    webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);

@@ -957,7 +952,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
    }

    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src, .src1 = idx, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
        .src0 = src,
        .src1 = idx,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
    };

    webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);

@@ -1032,13 +1030,13 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
                                           ggml_tensor * idx,
                                           ggml_tensor * dst) {
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src,
        .src1 = nullptr,
        .dst = dst,
        .max_wg_size = WEBGPU_MAX_WG_SIZE,
    };

    webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),

@@ -1108,88 +1106,29 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
            break;
    }

    int vectorized = 0;
    if (use_fast) {
        vectorized = (src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0);
        if (is_vec) {
            // We don't support vectorized mul_mat_vec for quantized types
            vectorized = vectorized && (src0->type < 2);
        }
    }

    // Create pipeline key
    bool supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
    ggml_webgpu_mul_mat_pipeline_key key = { .src0_type = src0->type,
                                             .src1_type = src1->type,
                                             .vectorized = use_fast ? vectorized : 0,
                                             .is_vec = (use_fast && is_vec) ? 1 : 0,
                                             .use_subgroup_matrix =
                                                 (use_fast && !is_vec && supports_subgroup_matrix) ? 1 : 0,
                                             .register_tile =
                                                 (use_fast && !is_vec && !supports_subgroup_matrix) ? 1 : 0 };

    // Build shader context
    ggml_webgpu_mul_mat_shader_lib_context shader_lib_ctx = { .key = key,
                                                              .max_subgroup_size =
                                                                  ctx->global_ctx->capabilities.max_subgroup_size,
                                                              .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
                                                              .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
                                                              .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k };
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src0,
        .src1 = src1,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
        .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
        .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
        .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
        .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
        .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
    };

    // Get or create pipeline
    webgpu_pipeline pipeline;
    const char * shader_src = nullptr;

    auto it = ctx->mul_mat_pipelines.find(key);
    if (it != ctx->mul_mat_pipelines.end()) {
        pipeline = it->second;
    if (use_fast && is_vec) {
        pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
    } else if (use_fast) {
        pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
    } else {
        // Select appropriate shader source based on key
        if (!use_fast) {
            // Use precompiled quantized shaders (mul_mat.tmpl.wgsl)
            // These are the fallback for quantized types not supported by fast paths
            shader_src = wgsl_mul_mat;
        } else {
            // Use JIT-compiled shader
            if (is_vec) {
                shader_src = wgsl_mul_mat_vec;
            } else if (key.use_subgroup_matrix) {
                shader_src = wgsl_mul_mat_subgroup_matrix;
            } else {
                shader_src = wgsl_mul_mat_reg_tile;
            }
        }

        if (shader_src) {
            ggml_webgpu_processed_shader processed =
                ggml_webgpu_preprocess_mul_mat_shader(ctx->p, shader_src, shader_lib_ctx);

            std::vector<wgpu::ConstantEntry> constants;
            if (shader_lib_ctx.key.is_vec) {
                auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
                constants.push_back({ nullptr, "WORKGROUP_SIZE", static_cast<double>(decisions->wg_size) });
                constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
                constants.push_back({ nullptr, "OUTPUTS_PER_WG", static_cast<double>(decisions->outputs_per_wg) });
            } else if (shader_lib_ctx.key.register_tile) {
                auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(processed.decisions.get());
                constants.push_back({ nullptr, "WORKGROUP_SIZE_M", static_cast<double>(decisions->wg_size_m) });
                constants.push_back({ nullptr, "WORKGROUP_SIZE_N", static_cast<double>(decisions->wg_size_n) });
                constants.push_back({ nullptr, "TILE_K", static_cast<double>(decisions->tile_k) });
            }
            // printf("DEBUG: Creating pipeline with variant='%s', "
            //        "constants.size()=%zu\n",
            //        processed.variant.c_str(), constants.size());

            pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(),
                                                   processed.variant.c_str(), constants);
            pipeline.context = processed.decisions;
            ctx->mul_mat_pipelines.emplace(key, pipeline);
        }
        pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
    }

    auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());

    // Build params
    std::vector<uint32_t> params = {
        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),

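For readers unfamiliar with the override-constant plumbing in this hunk: pipeline-overridable WGSL constants are supplied at pipeline-creation time through wgpu::ConstantEntry records. A self-contained sketch with the WebGPU C++ API; the helper and variable names here are illustrative, not the backend's actual ggml_webgpu_create_pipeline:

#include <webgpu/webgpu_cpp.h>

#include <cstdint>
#include <vector>

// Illustrative helper: builds a compute pipeline whose WGSL `override` constants
// WORKGROUP_SIZE and TILE_K are fixed at creation time.
static wgpu::ComputePipeline create_pipeline_with_overrides(wgpu::Device device,
                                                            wgpu::ShaderModule module,
                                                            uint32_t wg_size,
                                                            uint32_t tile_k) {
    std::vector<wgpu::ConstantEntry> constants;
    constants.push_back({ nullptr, "WORKGROUP_SIZE", static_cast<double>(wg_size) });
    constants.push_back({ nullptr, "TILE_K", static_cast<double>(tile_k) });

    wgpu::ComputePipelineDescriptor desc = {};
    desc.compute.module        = module;
    desc.compute.entryPoint    = "main";
    desc.compute.constants     = constants.data();
    desc.compute.constantCount = constants.size();
    return device.CreateComputePipeline(&desc);
}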
@@ -1230,13 +1169,17 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
    uint32_t wg_x = 1;
    uint32_t wg_y = 1;

    if (decisions->is_vec) {
    if (use_fast && is_vec) {
        auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());

        uint32_t batches = dst->ne[2] * dst->ne[3];
        uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
        uint32_t total_wg = output_groups * batches;
        wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
        wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
    } else if (use_fast) {
        auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());

        // Fast-path tiled/subgroup calculations
        uint32_t wg_m, wg_n;
        if (decisions->use_subgroup_matrix) {

@@ -1253,11 +1196,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
            wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
        }
        wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
    } else {
        // Non-fast-path quantized shaders (Q2_K, Q4_K, etc.)
        // Use the value from decisions instead of hardcoded constant
        wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->mul_mat_wg_size);
        wg_y = 1;
    } else { // legacy
        auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
        uint32_t wg_size = decisions->wg_size;
        wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
        wg_y = 1;
    }

    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);

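The dispatch math above flattens all outputs into a single workgroup count and, when that count exceeds maxComputeWorkgroupsPerDimension, spreads it over two dispatch dimensions. A generic sketch of that idea (not the backend's exact formula); CEIL_DIV is reproduced here so the snippet is self-contained, and the shader is expected to bounds-check the flattened workgroup id:

#include <cstdint>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

struct dispatch_2d {
    uint32_t wg_x;
    uint32_t wg_y;
};

// Split a 1-D workgroup count across x/y so neither dimension exceeds the limit.
static dispatch_2d split_dispatch(uint32_t total_wg, uint32_t max_per_dim) {
    dispatch_2d d;
    if (total_wg <= max_per_dim) {
        d.wg_x = total_wg;
        d.wg_y = 1;
    } else {
        d.wg_y = CEIL_DIV(total_wg, max_per_dim); // rows of workgroups
        d.wg_x = CEIL_DIV(total_wg, d.wg_y);      // columns per row; wg_x * wg_y >= total_wg
    }
    return d;
}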
@@ -1347,40 +1290,22 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                        .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
                        .size = ggml_webgpu_tensor_binding_size(ctx, dst) });

    bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
                     (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);

    ggml_webgpu_flash_attn_pipeline_key key = {
        .kv_type = K->type,
        .head_dim_qk = (uint32_t) Q->ne[0],
        .head_dim_v = (uint32_t) V->ne[0],
        .kv_direct = kv_direct,
        .has_mask = static_cast<bool>(has_mask),
        .has_sinks = static_cast<bool>(has_sinks),
        .uses_logit_softcap = logit_softcap != 0.0f,
    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = Q,
        .src1 = K,
        .src2 = V,
        .src3 = mask,
        .src4 = sinks,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
        .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
        .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
        .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
        .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
    };

    webgpu_pipeline pipeline;
    auto it = ctx->flash_attn_pipelines.find(key);
    if (it != ctx->flash_attn_pipelines.end()) {
        pipeline = it->second;
    } else {
        ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
            .key = key,
            .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
            .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
            .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
            .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
            .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size
        };

        ggml_webgpu_processed_shader processed =
            ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
        pipeline =
            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
        pipeline.context = processed.decisions;
        ctx->flash_attn_pipelines.emplace(key, pipeline);
    }
    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);

    auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());

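The kv_direct flag in this hunk gates a flash-attention variant on three type/divisibility checks. Restated as a standalone helper for clarity; the input struct is invented for the example, only the checks themselves come from the diff:

#include <cstdint>

struct kv_direct_inputs {
    bool     k_is_f16;    // K->type == GGML_TYPE_F16
    int64_t  head_dim_qk; // Q->ne[0]
    int64_t  kv_len;      // K->ne[1]
    uint32_t sg_mat_k;    // subgroup-matrix K dimension reported by the device
    uint32_t kv_seq_pad;  // GGML_WEBGPU_KV_SEQ_PAD
};

// kv_direct holds only when K is F16, the QK head dimension is a multiple of the
// subgroup-matrix K size, and the KV sequence length is padded to kv_seq_pad.
static bool kv_direct_possible(const kv_direct_inputs & in) {
    return in.k_is_f16 &&
           in.head_dim_qk % in.sg_mat_k == 0 &&
           in.kv_len % in.kv_seq_pad == 0;
}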
@@ -1402,7 +1327,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
        .inplace = inplace,
    };

    webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);

    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

@@ -1483,7 +1408,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
        .overlap = flags.overlap,
    };

    webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);

    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());

@@ -1860,19 +1785,18 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src
}

static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
    bool is_top_k = dst->op == GGML_OP_TOP_K;

    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src,
        .src1 = nullptr,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
    };

    webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
    auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get());

    webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);

@@ -2034,14 +1958,14 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src
    };

    ggml_webgpu_shader_lib_context shader_lib_ctx = {
        .src0 = src,
        .src1 = nullptr,
        .dst = dst,
        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
    };

    webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
    uint32_t wg_x = ggml_nrows(dst);
    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}

@@ -34,7 +34,6 @@ struct MulMatParams {
    broadcast3: u32
};

// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)

@@ -48,14 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 {
    return thread_id % WORKGROUP_SIZE_M;
}

override WORKGROUP_SIZE_M: u32;
override WORKGROUP_SIZE_N: u32;
override TILE_K: u32;

override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;

const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;

@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)

@@ -142,4 +136,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
        }
    }
}

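The shader hunk above trades pipeline-overridable constants (override WORKGROUP_SIZE_M, WORKGROUP_SIZE_N, TILE_K) for const expressions, which suggests the concrete values are now baked into the WGSL text before the module is created. A minimal sketch of such textual substitution on the host side; the helper name and the whole-word caveat are assumptions, not the repo's actual preprocessor:

#include <cstdint>
#include <string>

// Replace every occurrence of `name` in the WGSL source with a u32 literal.
// Sketch only: a real preprocessor would match whole identifiers, not substrings.
static void wgsl_define_u32(std::string & wgsl, const std::string & name, uint32_t value) {
    const std::string literal = std::to_string(value) + "u";
    size_t pos = 0;
    while ((pos = wgsl.find(name, pos)) != std::string::npos) {
        wgsl.replace(pos, name.size(), literal);
        pos += literal.size();
    }
}

// Usage sketch: bake the tile configuration into the shader so derived values
// such as TILE_SRC0_SHMEM can be plain `const` expressions (and legal array sizes).
// wgsl_define_u32(wgsl_src, "WORKGROUP_SIZE_M", 8);
// wgsl_define_u32(wgsl_src, "WORKGROUP_SIZE_N", 8);
// wgsl_define_u32(wgsl_src, "TILE_K", 32);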
@@ -98,16 +98,13 @@ struct MulMatParams {

@group(0) @binding(3) var<uniform> params: MulMatParams;

override WORKGROUP_SIZE: u32;
override TILE_K: u32;
override OUTPUTS_PER_WG: u32;
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;

// Shared memory for collaborative loading and reduction
var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>; // Cache vector tile
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
var<workgroup> partial_sums: array<f32, WG_SIZE>; // For reduction

@compute @workgroup_size(WORKGROUP_SIZE)
@compute @workgroup_size(WG_SIZE)
fn main(
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) wg_id: vec3<u32>,

@@ -150,7 +147,7 @@ fn main(
    let tile_size = min(TILE_K, params.k - k_tile);

    // Cooperatively load vector tile into shared memory (all threads)
    for (var i = thread_id * VEC_SIZE; i < tile_size; i += WORKGROUP_SIZE * VEC_SIZE) {
    for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
        shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
    }

@@ -168,7 +165,7 @@ fn main(
    workgroupBarrier();
    let group_base = thread_group * THREADS_PER_OUTPUT;
    let thread_base = group_base + thread_in_group;
    var offset = THREADS_PER_OUTPUT / 2;
    var offset: u32 = THREADS_PER_OUTPUT / 2;
    while (offset > 0) {
        if (thread_in_group < offset) {
            partial_sums[thread_base] += partial_sums[thread_base + offset];

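The loop in the last hunk is a standard shared-memory tree reduction: each halving step folds the upper half of a group's partial sums onto the lower half. The same arithmetic in plain C++, sequentially simulating one group of THREADS_PER_OUTPUT partial sums (illustrative only):

#include <cstdint>
#include <vector>

// After the loop, partial_sums[0] holds the sum of all entries.
// The size must be a power of two, matching the shader's assumption.
static float reduce_partial_sums(std::vector<float> partial_sums) {
    const uint32_t threads_per_output = (uint32_t) partial_sums.size();
    for (uint32_t offset = threads_per_output / 2; offset > 0; offset /= 2) {
        for (uint32_t thread_in_group = 0; thread_in_group < offset; ++thread_in_group) {
            // In WGSL this is one thread per index, with workgroupBarrier() between steps.
            partial_sums[thread_in_group] += partial_sums[thread_in_group + offset];
        }
    }
    return partial_sums[0];
}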