From 0eb4764182df1f031f92b4121f3bf8fbe8026565 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 28 Mar 2026 08:44:56 +0100 Subject: [PATCH] vulkan: add noncontiguous GLU support (#21081) * vulkan: add noncontiguous GLU support * fix compile issue --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 39 +++++++++++++------ .../ggml-vulkan/vulkan-shaders/glu_head.glsl | 10 +++++ .../ggml-vulkan/vulkan-shaders/glu_main.glsl | 22 ++++++++--- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 221e6fa04e..15ed5b2a79 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants { uint32_t mode; // 0: default, 1: swapped, 2: split float alpha; // for swiglu_oai float limit; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t ne01; + uint32_t ne02; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t ne11; + uint32_t ne12; }; struct vk_op_unary_push_constants { @@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) { } else { device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); } - vk::DeviceCreateInfo device_create_info; + vk::DeviceCreateInfo device_create_info{}; std::vector device_extensions; vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); @@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->name = GGML_VK_NAME + std::to_string(idx); - device_create_info = { - vk::DeviceCreateFlags(), - device_queue_create_infos, - {}, - device_extensions - }; + device_create_info + .setFlags(vk::DeviceCreateFlags()) + .setQueueCreateInfos(device_queue_create_infos) + .setPEnabledExtensionNames(device_extensions); device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); @@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const float alpha = op_params_f[2]; const float limit = op_params_f[3]; - GGML_ASSERT(ggml_is_contiguous(src0)); - if (!split) { GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); } else { @@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)dst->ne[0], mode, alpha, - limit + limit, + (uint32_t)(src0->nb[1] / src0->nb[0]), + (uint32_t)(src0->nb[2] / src0->nb[0]), + (uint32_t)(src0->nb[3] / src0->nb[0]), + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + (uint32_t)(dst->nb[1] / dst->nb[0]), + (uint32_t)(dst->nb[2] / dst->nb[0]), + (uint32_t)(dst->nb[3] / dst->nb[0]), + (uint32_t)dst->ne[1], + (uint32_t)dst->ne[2] }); } @@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous(op->src[0]) && - (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type); default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 2168989340..95298922d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -16,4 +16,14 @@ layout (push_constant) uniform parameter uint mode; float alpha; float limit; + uint nb01; + uint nb02; + uint nb03; + uint ne01; + uint ne02; + uint nb11; + uint nb12; + uint nb13; + uint ne11; + uint ne12; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl index 85cf65a9ec..359461306a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl @@ -8,22 +8,32 @@ void main() { const uint row = i / p.ne20; const uint col = i - row * p.ne20; + const uint i3 = row / (p.ne01 * p.ne02); + const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01; + const uint i1 = row % p.ne01; + const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col; + + const uint dst_i3 = row / (p.ne11 * p.ne12); + const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11; + const uint dst_i1 = row % p.ne11; + const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col; + if (p.mode == 0) { // Default const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); } else if (p.mode == 1) { // Swapped const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); } else { // Split - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); } }