vulkan: Handle argsort with a large number of rows (#16851)
This commit is contained in:
parent
8b11deea46
commit
052df28b0e
|
|
@ -1082,6 +1082,7 @@ struct vk_op_soft_max_push_constants {
|
||||||
|
|
||||||
struct vk_op_argsort_push_constants {
|
struct vk_op_argsort_push_constants {
|
||||||
uint32_t ncols;
|
uint32_t ncols;
|
||||||
|
uint32_t nrows;
|
||||||
int32_t order;
|
int32_t order;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -8708,6 +8709,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
|
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
|
||||||
|
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||||
break;
|
break;
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
{
|
{
|
||||||
|
|
@ -9954,9 +9956,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||||
int32_t * op_params = (int32_t *)dst->op_params;
|
int32_t * op_params = (int32_t *)dst->op_params;
|
||||||
|
|
||||||
uint32_t ncols = src0->ne[0];
|
uint32_t ncols = src0->ne[0];
|
||||||
|
uint32_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
|
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
|
||||||
ncols,
|
ncols,
|
||||||
|
nrows,
|
||||||
op_params[0],
|
op_params[0],
|
||||||
}, dryrun);
|
}, dryrun);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ layout (binding = 1) buffer D {int data_d[];};
|
||||||
|
|
||||||
layout (push_constant) uniform parameter {
|
layout (push_constant) uniform parameter {
|
||||||
uint ncols;
|
uint ncols;
|
||||||
|
uint nrows;
|
||||||
uint order;
|
uint order;
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
|
|
@ -26,10 +27,9 @@ void swap(uint idx0, uint idx1) {
|
||||||
dst_row[idx1] = tmp;
|
dst_row[idx1] = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
void argsort(bool needs_bounds_check) {
|
void argsort(bool needs_bounds_check, const uint row) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
const int col = int(gl_LocalInvocationID.x);
|
const int col = int(gl_LocalInvocationID.x);
|
||||||
const uint row = gl_WorkGroupID.y;
|
|
||||||
|
|
||||||
const uint row_offset = row * p.ncols;
|
const uint row_offset = row * p.ncols;
|
||||||
|
|
||||||
|
|
@ -72,8 +72,16 @@ void argsort(bool needs_bounds_check) {
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
if (p.ncols == BLOCK_SIZE) {
|
if (p.ncols == BLOCK_SIZE) {
|
||||||
argsort(false);
|
uint row = gl_WorkGroupID.y;
|
||||||
|
while (row < p.nrows) {
|
||||||
|
argsort(false, row);
|
||||||
|
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
argsort(true);
|
uint row = gl_WorkGroupID.y;
|
||||||
|
while (row < p.nrows) {
|
||||||
|
argsort(true, row);
|
||||||
|
row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue