vulkan: Optimize GGML_OP_CUMSUM (#18417)
* vulkan: Optimize GGML_OP_CUMSUM

  There are two paths: the preexisting one, which does a whole row per workgroup in a single shader, and a new one that splits each row into multiple blocks and does two passes. The first pass computes partials within a block; the second adds the block partials to compute the final result. The multipass shader is used when there are a small number of large rows. In the whole-row shader, handle multiple elements per invocation.

* use 2 ELEM_PER_THREAD for AMD/Intel

* address feedback
This commit is contained in:
parent 706e3f93a6
commit 18ddaea2ae
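The change in a nutshell, as a hypothetical host-side C++ sketch (not part of the patch): the multipass path splits a row into fixed-size blocks, pass 1 produces an inclusive scan within each block and records each block's total in a temporary buffer, and pass 2 adds the sum of all preceding block totals to every element of the block. Block size, buffer names and the helper below are illustrative only.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Illustrative scalar reference of the two-pass scheme for a single row.
    // The real shaders use a 256-wide workgroup per block and a GPU temp buffer.
    static void cumsum_two_pass_ref(const std::vector<float> & src, std::vector<float> & dst, size_t block = 256) {
        const size_t n       = src.size();
        const size_t nblocks = (n + block - 1) / block;
        std::vector<float> block_total(nblocks, 0.0f); // plays the role of the temp buffer
        dst.resize(n);

        // Pass 1: inclusive scan inside each block, remember the block's last partial.
        for (size_t b = 0; b < nblocks; ++b) {
            float run = 0.0f;
            for (size_t i = b * block; i < std::min(n, (b + 1) * block); ++i) {
                run   += src[i];
                dst[i] = run;
            }
            block_total[b] = run;
        }

        // Pass 2: add the sum of all preceding blocks' totals to each element.
        float prefix = 0.0f;
        for (size_t b = 1; b < nblocks; ++b) {
            prefix += block_total[b - 1];
            for (size_t i = b * block; i < std::min(n, (b + 1) * block); ++i) {
                dst[i] += prefix;
            }
        }
    }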
ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -765,6 +765,9 @@ struct vk_device_struct {
     vk_pipeline pipeline_topk_f32[num_topk_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_cumsum_f32;
+    vk_pipeline pipeline_cumsum_small_f32;
+    vk_pipeline pipeline_cumsum_multipass1_f32;
+    vk_pipeline pipeline_cumsum_multipass2_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
@@ -4178,7 +4181,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
+    const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
 
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
@@ -8804,7 +8811,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_CUMSUM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_cumsum_f32;
+            if (src0->ne[0] <= 512) {
+                return ctx->device->pipeline_cumsum_small_f32;
+            } else {
+                return ctx->device->pipeline_cumsum_f32;
+            }
         }
         return nullptr;
     case GGML_OP_SOLVE_TRI:
@@ -10708,8 +10719,50 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
 }
 
 static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
-    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
+    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
+    // For fewer, larger rows, use the multipass shader to spread each row across SMs.
+    if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
+        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
+        return;
+    }
+
+    // First pass computes partial sums within a block, and stores the last partial
+    // to the temp buffer. Second pass sums the block partials from the temp buffer
+    // and adds that to the result of the first pass.
+    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
+    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
+    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+
+    std::array<uint32_t, 3> elements;
+
+    elements[0] = dst->ne[0];
+    elements[1] = (uint32_t)ggml_nrows(dst);
+    elements[2] = 1;
+
+    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
+
+    if (ctx->prealloc_size_split_k < temp_size) {
+        ctx->prealloc_size_split_k = temp_size;
+        ggml_vk_preallocate_buffers(ctx, subctx);
+    }
+
+    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
+
+    if (ctx->prealloc_split_k_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
+    ggml_vk_sync_buffers(ctx, subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
+
+    ctx->prealloc_split_k_need_sync = true;
 }
 
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
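A worked example of the dispatch heuristic above (numbers illustrative): a tensor with a single row of 65536 elements fails both conditions on a GPU with, say, 32 shader cores (65536 > 4096 and 1 row < 32 cores), so it takes the multipass path. Assuming the usual division of the element counts by the pipeline's 256-wide workgroup, pass 1 then launches ceil(65536 / 256) = 256 workgroups for that row, each writing one block partial to the split_k temp buffer, and pass 2 launches the same grid to fold those partials back in. A 4096-element row, or a batch with at least as many rows as shader cores, stays on the single-pass whole-row shader.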
ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp
@@ -14,6 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 128;
 layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+layout (constant_id = 2) const uint ELEM_PER_THREAD = 4;
 
 #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
@@ -38,32 +39,45 @@ void main() {
         last_sum = 0;
     }
 
-    uint col = tid;
-    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
+    uint col = tid * ELEM_PER_THREAD;
+    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
     for (int i = 0; i < num_iter; ++i) {
-        FLOAT_TYPE v = 0;
-        if (col < p.n_cols) {
-            v = FLOAT_TYPE(data_a[src_idx + col]);
+        FLOAT_TYPE v[ELEM_PER_THREAD];
+        FLOAT_TYPE thread_sum = 0;
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            if (col + j < p.n_cols) {
+                thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
+            }
+            v[j] = thread_sum;
         }
-        v = subgroupInclusiveAdd(v);
 
+        thread_sum = subgroupExclusiveAdd(thread_sum);
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            v[j] += thread_sum;
+        }
         // Store the largest partial sum for each subgroup, then add the partials for all
         // lower subgroups and the final partial sum from the previous iteration.
         if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
-            partial[subgroup_id] = v;
+            partial[subgroup_id] = v[ELEM_PER_THREAD - 1];
         }
         barrier();
-        for (int j = 0; j < subgroup_id; ++j) {
-            v += partial[j];
+        for (int s = 0; s < subgroup_id; ++s) {
+            [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+                v[j] += partial[s];
+            }
         }
-        v += last_sum;
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            v[j] += last_sum;
+        }
         barrier();
         if (tid == BLOCK_SIZE - 1) {
-            last_sum = v;
+            last_sum = v[ELEM_PER_THREAD - 1];
        }
-        if (col < p.n_cols) {
-            data_d[dst_idx + col] = D_TYPE(v);
+        [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
+            if (col + j < p.n_cols) {
+                data_d[dst_idx + col + j] = D_TYPE(v[j]);
+            }
         }
-        col += BLOCK_SIZE;
+        col += BLOCK_SIZE * ELEM_PER_THREAD;
     }
 }
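The restructuring of the whole-row shader above reads as follows: each invocation now loads ELEM_PER_THREAD consecutive values and keeps a running sum of them, a single subgroupExclusiveAdd over the per-thread totals then gives every invocation the offset contributed by the lower lanes, and the cross-subgroup partial[] / last_sum bookkeeping is unchanged except that it uses the last element of the per-thread array. A hypothetical scalar emulation of the per-subgroup part (names and sizes are illustrative, not from the patch):

    #include <cstddef>
    #include <vector>

    // Emulates one subgroup of the reworked shader: per-thread running sums followed
    // by an exclusive scan over the per-thread totals (what subgroupExclusiveAdd does).
    static std::vector<float> subgroup_scan_ref(const std::vector<float> & x,
                                                size_t threads, size_t elem_per_thread) {
        std::vector<float> out(x.size(), 0.0f);
        std::vector<float> thread_total(threads, 0.0f);

        for (size_t t = 0; t < threads; ++t) {
            float run = 0.0f;                         // per-thread thread_sum
            for (size_t j = 0; j < elem_per_thread; ++j) {
                const size_t idx = t * elem_per_thread + j;
                if (idx < x.size()) {
                    run += x[idx];
                    out[idx] = run;                   // v[j]: partial within this thread
                }
            }
            thread_total[t] = run;
        }

        float prefix = 0.0f;                          // exclusive scan of thread totals
        for (size_t t = 0; t < threads; ++t) {
            for (size_t j = 0; j < elem_per_thread; ++j) {
                const size_t idx = t * elem_per_thread + j;
                if (idx < out.size()) {
                    out[idx] += prefix;               // v[j] += thread_sum (post-scan)
                }
            }
            prefix += thread_total[t];
        }
        return out;
    }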
ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp (new file)
@@ -0,0 +1,60 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+layout (binding = 2) writeonly buffer T {D_TYPE data_t[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.y;
+    const uint tid = gl_LocalInvocationID.x;
+    const uint col = gl_GlobalInvocationID.x;
+
+    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+    const uint i03_offset = i03 * p.ne01*p.ne02;
+    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+    const uint i01 = row - i03_offset - i02*p.ne01;
+
+    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+    uint subgroup_id = tid / SUBGROUP_SIZE;
+
+    FLOAT_TYPE v = 0;
+    if (col < p.n_cols) {
+        v = FLOAT_TYPE(data_a[src_idx + col]);
+    }
+    v = subgroupInclusiveAdd(v);
+
+    // Store the largest partial sum for each subgroup, then add the partials for all
+    // lower subgroups and the final partial sum from the previous iteration.
+    if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
+        partial[subgroup_id] = v;
+    }
+    barrier();
+    for (int j = 0; j < subgroup_id; ++j) {
+        v += partial[j];
+    }
+    barrier();
+    if (tid == BLOCK_SIZE - 1) {
+        data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v;
+    }
+    if (col < p.n_cols) {
+        data_d[dst_idx + col] = D_TYPE(v);
+    }
+}
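Pass 1 above is essentially the body of the single-pass shader's loop, run once per block: each workgroup does an inclusive scan of its block (subgroup scan plus the shared partial[] fix-up), writes the scanned values to data_d, and has its last invocation store the block total to data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row]. So, assuming a 65536-element row split into 256-wide workgroups, block b of row r leaves its total at data_t[b + 256 * r] for pass 2 to consume.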
ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp (new file)
@@ -0,0 +1,66 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer T {D_TYPE data_t[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.y;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+    const uint i03_offset = i03 * p.ne01*p.ne02;
+    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+    const uint i01 = row - i03_offset - i02*p.ne01;
+
+    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+    const uint col = gl_GlobalInvocationID.x;
+
+    float v = 0;
+    // prefetch value we're adding to
+    if (col < p.n_cols) {
+        v = data_d[dst_idx + col];
+    }
+
+    // compute the sum of all previous blocks
+    uint c = tid;
+    float sum = 0;
+    while (c < gl_WorkGroupID.x) {
+        sum += data_t[c + gl_NumWorkGroups.x * row];
+        c += BLOCK_SIZE;
+    }
+
+    sum = subgroupAdd(sum);
+    if (gl_SubgroupInvocationID == 0) {
+        temp[gl_SubgroupID] = sum;
+    }
+    barrier();
+    sum = 0;
+    [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) {
+        sum += temp[s];
+    }
+
+    // Add the sum to what the first pass computed
+    if (col < p.n_cols) {
+        data_d[dst_idx + col] = v + sum;
+    }
+}
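Pass 2 then revisits the same grid: each invocation re-reads its own pass-1 output from data_d, the workgroup's threads cooperatively sum the totals of all preceding blocks from data_t (each thread walks a BLOCK_SIZE-strided subset, then subgroupAdd and the shared temp[] array reduce the per-thread sums), and that per-block offset is added back into data_d. Continuing the illustrative example above, workgroup 3 of row r adds data_t[0 + 256*r] + data_t[1 + 256*r] + data_t[2 + 256*r] to each of its outputs.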
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -944,6 +944,8 @@ void process_shaders() {
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
     string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}));
 