vulkan: Implement GGML_OP_CUMSUM (#17479)
parent 583cb83416
commit b3b03a7baf
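For context: GGML_OP_CUMSUM computes, for every row of the source tensor, the inclusive prefix sum along ne[0], and the new Vulkan shader parallelizes exactly this per-row scan. A minimal CPU reference sketch (contiguous float rows assumed; the function name and layout here are illustrative, not part of the commit):

#include <cstdint>

// Inclusive cumulative sum along each row of length n_cols:
// dst[r][c] = src[r][0] + src[r][1] + ... + src[r][c]
static void cumsum_rows_reference(const float * src, float * dst,
                                  int64_t n_rows, int64_t n_cols) {
    for (int64_t r = 0; r < n_rows; ++r) {
        float running = 0.0f;
        for (int64_t c = 0; c < n_cols; ++c) {
            running += src[r * n_cols + c];
            dst[r * n_cols + c] = running; // inclusive: the current element is included
        }
    }
}

The changes below first register the new pipeline and dispatch paths in ggml-vulkan.cpp, then add the cumsum.comp shader and a shared sum_rows.glsl header.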
ggml-vulkan.cpp
@@ -705,6 +705,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
+    vk_pipeline pipeline_cumsum_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -3968,6 +3969,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
+
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
 #define IM2COL(bda) \
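The two specialization constants passed for cumsum, { 128, device->subgroup_size }, feed the shader's constant_id 0 (BLOCK_SIZE, which also drives local_size_x_id = 0) and constant_id 1 (SUBGROUP_SIZE); the trailing arguments appear to request full-subgroup execution at the device's subgroup size. For readers unfamiliar with specialization constants, a plain-Vulkan sketch of that mapping (illustrative only; ggml_vk_create_pipeline does this internally and its code is not shown here):

#include <vulkan/vulkan.h>
#include <cstdint>

// Maps two 32-bit values onto constant_id 0 and constant_id 1 of a compute shader.
static VkSpecializationInfo make_spec_info(const uint32_t (&values)[2]) {
    static const VkSpecializationMapEntry entries[2] = {
        { 0, 0,                sizeof(uint32_t) }, // constant_id 0 -> BLOCK_SIZE
        { 1, sizeof(uint32_t), sizeof(uint32_t) }, // constant_id 1 -> SUBGROUP_SIZE
    };
    VkSpecializationInfo info{};
    info.mapEntryCount = 2;
    info.pMapEntries   = entries;
    info.dataSize      = sizeof(values);
    info.pData         = values;
    return info;
}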
@@ -8457,6 +8460,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_sum_rows_f32;
         }
         return nullptr;
+    case GGML_OP_CUMSUM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_cumsum_f32;
+        }
+        return nullptr;
     case GGML_OP_ARGMAX:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
             return ctx->device->pipeline_argmax_f32;
@@ -8821,6 +8829,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_CUMSUM:
     case GGML_OP_MEAN:
     case GGML_OP_ARGMAX:
         {
@@ -10150,6 +10159,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
 }
 
+static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
+}
+
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
 }
@@ -11749,6 +11763,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SUM_ROWS:
         ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
 
         break;
+    case GGML_OP_CUMSUM:
+        ggml_vk_cumsum(ctx, compute_ctx, src0, node);
+
+        break;
     case GGML_OP_MEAN:
         ggml_vk_mean(ctx, compute_ctx, src0, node);
@@ -13786,6 +13804,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_SUM_ROWS:
     case GGML_OP_MEAN:
         return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
+    case GGML_OP_CUMSUM:
+        {
+            ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+            auto device = ggml_vk_get_device(ctx->device);
+            if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
+                return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
+            }
+            return false;
+        }
     case GGML_OP_ARGMAX:
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
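The new supports_op branch only advertises CUMSUM when the device has subgroup arithmetic and can guarantee full subgroups, which the shader needs for subgroupInclusiveAdd to cover every lane. My reading is that these backend flags correspond roughly to the following Vulkan capability queries; this sketch is illustrative and is not how ggml-vulkan actually populates its device struct:

#include <vulkan/vulkan.h>

// Hypothetical capability check mirroring device->subgroup_arithmetic and
// device->subgroup_require_full_support (the exact mapping in ggml-vulkan may differ).
static bool cumsum_capable(VkPhysicalDevice phys) {
    VkPhysicalDeviceSubgroupProperties subgroup{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES };
    VkPhysicalDeviceProperties2 props2{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2, &subgroup };
    vkGetPhysicalDeviceProperties2(phys, &props2);

    VkPhysicalDeviceSubgroupSizeControlFeatures size_ctl{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES };
    VkPhysicalDeviceFeatures2 feats2{ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2, &size_ctl };
    vkGetPhysicalDeviceFeatures2(phys, &feats2);

    const bool arithmetic = (subgroup.supportedOperations & VK_SUBGROUP_FEATURE_ARITHMETIC_BIT) != 0;
    return arithmetic && size_ctl.computeFullSubgroups == VK_TRUE;
}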
@@ -14436,6 +14463,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_CUMSUM) {
+        tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_MEAN) {
         tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_ARGMAX) {
cumsum.comp (new file)
@@ -0,0 +1,69 @@
+#version 450
+
+#include "types.glsl"
+#include "sum_rows.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
+shared FLOAT_TYPE last_sum;
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
+    const uint i03_offset = i03 * p.ne01*p.ne02;
+    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
+    const uint i01 = row - i03_offset - i02*p.ne01;
+
+    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
+    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
+
+    uint subgroup_id = tid / SUBGROUP_SIZE;
+
+    if (tid == 0) {
+        last_sum = 0;
+    }
+
+    uint col = tid;
+    uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
+    for (int i = 0; i < num_iter; ++i) {
+        FLOAT_TYPE v = 0;
+        if (col < p.n_cols) {
+            v = FLOAT_TYPE(data_a[src_idx + col]);
+        }
+        v = subgroupInclusiveAdd(v);
+
+        // Store the largest partial sum for each subgroup, then add the partials for all
+        // lower subgroups and the final partial sum from the previous iteration.
+        if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
+            partial[subgroup_id] = v;
+        }
+        barrier();
+        for (int j = 0; j < subgroup_id; ++j) {
+            v += partial[j];
+        }
+        v += last_sum;
+        barrier();
+        if (tid == BLOCK_SIZE - 1) {
+            last_sum = v;
+        }
+        if (col < p.n_cols) {
+            data_d[dst_idx + col] = D_TYPE(v);
+        }
+        col += BLOCK_SIZE;
+    }
+}
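To make the scan structure of the shader above easier to follow, here is a small CPU model of what one workgroup does for one row. All names are mine and the loop structure is simplified; it is a sketch of the algorithm, not code from the commit: each BLOCK_SIZE tile is scanned with SUBGROUP_SIZE-wide inclusive scans (subgroupInclusiveAdd), the totals of lower subgroups are added on top (the partial[] array), and the running total of all previous tiles is carried in (last_sum).

#include <algorithm>
#include <cstddef>
#include <vector>

static std::vector<float> cumsum_tiled(const std::vector<float> & row,
                                       size_t block_size, size_t subgroup_size) {
    std::vector<float> out(row.size());
    float carry = 0.0f;                                   // plays the role of last_sum
    for (size_t base = 0; base < row.size(); base += block_size) {
        const size_t tile = std::min(block_size, row.size() - base);
        std::vector<float> v(row.begin() + base, row.begin() + base + tile);

        // 1) inclusive scan inside each subgroup (subgroupInclusiveAdd)
        std::vector<float> subgroup_total((tile + subgroup_size - 1) / subgroup_size, 0.0f);
        for (size_t s = 0; s * subgroup_size < tile; ++s) {
            float acc = 0.0f;
            for (size_t lane = 0; lane < subgroup_size && s * subgroup_size + lane < tile; ++lane) {
                acc += v[s * subgroup_size + lane];
                v[s * subgroup_size + lane] = acc;
            }
            subgroup_total[s] = acc;                      // partial[] written by the last lane
        }

        // 2) add the totals of all lower subgroups plus the carry from previous tiles
        for (size_t s = 0; s * subgroup_size < tile; ++s) {
            float below = carry;
            for (size_t j = 0; j < s; ++j) {
                below += subgroup_total[j];
            }
            for (size_t lane = 0; lane < subgroup_size && s * subgroup_size + lane < tile; ++lane) {
                v[s * subgroup_size + lane] += below;
            }
        }

        // 3) the last element of the tile becomes the carry for the next iteration
        carry = v[tile - 1];
        std::copy(v.begin(), v.end(), out.begin() + base);
    }
    return out;
}

On the GPU, step 2 is why the shader needs two barrier() calls per iteration: partial[] must be fully written before it is read, and last_sum must be read by every thread before the last thread overwrites it for the next tile.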
sum_rows.comp
@@ -1,6 +1,7 @@
 #version 450
 
 #include "types.glsl"
+#include "sum_rows.glsl"
 
 #extension GL_EXT_control_flow_attributes : enable
 
@@ -11,30 +12,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 
-layout (push_constant) uniform parameter
-{
-    uint n_cols;
-    uint ne01, ne02;
-    uint nb01, nb02, nb03;
-    uint nb11, nb12, nb13;
-    float weight;
-    uint misalign_offsets;
-    uint ne0_12mp, ne0_12L;
-    uint ne0_1mp, ne0_1L;
-} p;
-
-uint get_aoffset() { return p.misalign_offsets >> 16; }
-uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
-
-// see init_fastdiv_values in ggml-vulkan.cpp
-uint fastdiv(uint n, uint mp, uint L) {
-    uint msbs, lsbs;
-    // msbs = mulhi(n, mp)
-    umulExtended(n, mp, msbs, lsbs);
-    return (msbs + n) >> L;
-}
-
-
 shared FLOAT_TYPE tmp[BLOCK_SIZE];
 
 void main() {
sum_rows.glsl (new file)
@@ -0,0 +1,25 @@
+
+// vk_op_sum_rows_push_constants
+layout (push_constant) uniform parameter
+{
+    uint n_cols;
+    uint ne01, ne02;
+    uint nb01, nb02, nb03;
+    uint nb11, nb12, nb13;
+    float weight;
+    uint misalign_offsets;
+    uint ne0_12mp, ne0_12L;
+    uint ne0_1mp, ne0_1L;
+} p;
+
+uint get_aoffset() { return p.misalign_offsets >> 16; }
+uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
+
+// see init_fastdiv_values in ggml-vulkan.cpp
+uint fastdiv(uint n, uint mp, uint L) {
+    uint msbs, lsbs;
+    // msbs = mulhi(n, mp)
+    umulExtended(n, mp, msbs, lsbs);
+    return (msbs + n) >> L;
+}
+
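The fastdiv helper above does unsigned division by a constant with one multiply-high, an add, and a shift; the magic pair (mp, L) is computed once on the host (the comment points to init_fastdiv_values in ggml-vulkan.cpp, whose source is not part of this diff). A hedged sketch of the textbook construction this corresponds to, with hypothetical helper names, for readers who want to see where the constants come from:

#include <cassert>
#include <cstdint>

// For a divisor d >= 1, pick L = ceil(log2(d)) and mp = floor(2^32 * (2^L - d) / d) + 1.
// Then (mulhi(n, mp) + n) >> L equals n / d for 32-bit n (the add is done in 64 bits
// here; the shader can stay in 32 bits because its indices are far below 2^32).
static void make_fastdiv_constants(uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;
    while (L < 32 && (uint64_t{1} << L) < d) {
        ++L;
    }
    mp = (uint32_t)(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
}

// Same operation as the GLSL fastdiv(): umulExtended yields the high 32 bits of n * mp.
static uint32_t fastdiv_u32(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t msbs = (uint32_t)(((uint64_t)n * mp) >> 32);
    return (uint32_t)(((uint64_t)msbs + n) >> L);
}

int main() {
    uint32_t mp, L;
    make_fastdiv_constants(3, mp, L);
    assert(fastdiv_u32(10, mp, L) == 10u / 3u);
    assert(fastdiv_u32(9,  mp, L) == 3u);
    return 0;
}

The shader uses this to turn the flat workgroup row index back into (i01, i02, i03) coordinates without doing integer division on the GPU.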
vulkan-shaders-gen.cpp
@@ -916,6 +916,7 @@ void process_shaders() {
     string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
+    string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
     for (std::string dim_str : {"", "_3d"}) {
         for (bool bda : {false, true}) {