From ca4a8370bc1ebf267073cfa29067ebeff7ab8015 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Wed, 7 Jan 2026 05:03:32 -0600
Subject: [PATCH] vulkan: reject ops when a tensor is too large to allocate (#18646)

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 29 +++++++++++++----------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 1f255b705e..d68735a040 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -14305,6 +14305,19 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const
 }
 
 static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    const vk_device& device = ggml_vk_get_device(ctx->device);
+
+    // reject any tensors larger than the max buffer size
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) {
+            return false;
+        }
+    }
+    if (ggml_nbytes(op) > device->max_buffer_size) {
+        return false;
+    }
+
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -14353,8 +14366,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_MUL_MAT_ID:
             {
                 ggml_type src0_type = op->src[0]->type;
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
                 if (op->op == GGML_OP_MUL_MAT_ID) {
                     if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
                         // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
@@ -14415,8 +14426,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 bool coopmat2 = device->coopmat2;
                 uint32_t HSK = op->src[1]->ne[0];
                 uint32_t HSV = op->src[2]->ne[0];
@@ -14638,8 +14647,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                     return false;
                 }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 // pipeline_argsort_large_f32 requires vulkan memory model.
                 if (device->vulkan_memory_model) {
                     return true;
@@ -14652,8 +14659,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
                     return false;
                 }
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 // We could potentially support larger, using argsort to sort the
                 // whole thing. Not clear if this is needed.
                 uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
@@ -14700,8 +14705,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_CUMSUM:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                auto device = ggml_vk_get_device(ctx->device);
                 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
                     return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
                 }
@@ -14709,9 +14712,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
         case GGML_OP_SOLVE_TRI:
             {
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
-
                 if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
                     return false;
                 }
@@ -14776,9 +14776,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return false;
                 }
 
-                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                const vk_device& device = ggml_vk_get_device(ctx->device);
-
                 const uint32_t SPLIT_H = 16;
                 size_t stateC_size = SPLIT_H * d_state * sizeof(float);
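
For reference, a minimal standalone sketch of the check this patch introduces at the top of ggml_backend_vk_device_supports_op: every source tensor and the destination tensor are measured with ggml_nbytes(), and the op is reported as unsupported when any of them would exceed the device's maximum buffer size, so the scheduler can fall back to another backend instead of failing the allocation later. The tensor_t struct, MAX_SRC, and MAX_BUFFER_SIZE below are simplified stand-ins for illustration, not the real ggml/Vulkan definitions.

// Standalone illustration only: tensor_t and MAX_BUFFER_SIZE are simplified
// stand-ins for ggml_tensor and device->max_buffer_size in the patch above.
#include <cstddef>
#include <cstdio>

constexpr int    MAX_SRC         = 10;                         // mirrors GGML_MAX_SRC
constexpr size_t MAX_BUFFER_SIZE = 4ull * 1024 * 1024 * 1024;  // example device limit

struct tensor_t {
    size_t    nbytes;        // bytes the tensor would need in a device buffer
    tensor_t *src[MAX_SRC];  // source tensors (may be null)
};

// Reject the op if any of its tensors could not fit in a single device buffer.
static bool device_supports_op(const tensor_t *op) {
    for (int i = 0; i < MAX_SRC; i++) {
        if (op->src[i] && op->src[i]->nbytes > MAX_BUFFER_SIZE) {
            return false;
        }
    }
    if (op->nbytes > MAX_BUFFER_SIZE) {
        return false;
    }
    return true;  // size check passed; the real function then inspects op->op
}

int main() {
    tensor_t small = { 1024, {} };
    tensor_t huge  = { MAX_BUFFER_SIZE + 1, {} };
    tensor_t dst   = { 2048, {} };
    dst.src[0] = &small;
    printf("small src: %s\n", device_supports_op(&dst) ? "supported" : "rejected");
    dst.src[1] = &huge;
    printf("huge src:  %s\n", device_supports_op(&dst) ? "supported" : "rejected");
    return 0;
}

Hoisting ctx and device to the top of the function is also what allows the later per-op cases to drop their duplicated lookups, as the removed lines in the hunks above show.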