vulkan : move contiguous checks to device_supports_op (#17490)

* vulkan : remove op_supports_incontiguous and add missing constraints in device_supports_op

* im2col: remove contraints on src0 (kernel input)
This commit is contained in:
Acly 2025-11-27 06:54:19 +01:00 committed by GitHub
parent 142df17c9c
commit b78db3bd50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 35 additions and 50 deletions

View File

@ -8687,41 +8687,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
GGML_UNUSED(src2);
}
static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
switch (op) {
case GGML_OP_CPY:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_SET_ROWS:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return true;
default:
return false;
}
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@ -8806,7 +8771,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << ggml_op_name(op) << ")");
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
GGML_ASSERT(dst->buffer != nullptr);
const uint64_t ne00 = src0->ne[0];
const uint64_t ne01 = src0->ne[1];
@ -8837,22 +8801,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
// Compute misalignment offset for descriptors and store it in in push constants.
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
std::array<uint32_t, 3> elements;
// Single call if dimension 2 is contiguous
GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
switch (op) {
case GGML_OP_NORM:
case GGML_OP_RMS_NORM_BACK:
@ -13876,15 +13835,17 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
op->type == GGML_TYPE_F32;
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_ARGSORT:
@ -13919,17 +13880,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONCAT:
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
case GGML_OP_ADD1:
return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
case GGML_OP_ARANGE:
case GGML_OP_FILL:
return op->type == GGML_TYPE_F32;
case GGML_OP_SCALE:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_PAD:
case GGML_OP_ROLL:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_DIAG_MASK_INF:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
&& (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
case GGML_OP_SOFT_MAX_BACK:
return true;
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
&& ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
@ -13944,15 +13917,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
case GGML_OP_ARGMAX:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_COUNT_EQUAL:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
&& ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1])
&& op->src[1]->type == GGML_TYPE_F32
&& (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_IM2COL_3D:
return op->src[1]->type == GGML_TYPE_F32
&& (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_TIMESTEP_EMBEDDING:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D_DW:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
&& op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_POOL_2D:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
return true; // all inputs are contiguous, see ggml.c
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
@ -13993,7 +13978,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
}
case GGML_OP_SSM_CONV:
return true;
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D: