vulkan: add noncontiguous GLU support (#21081)
* vulkan: add noncontiguous GLU support
* fix compile issue
This commit is contained in:
parent
1f5d15e665
commit
0eb4764182
|
|
@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants {
|
||||||
uint32_t mode; // 0: default, 1: swapped, 2: split
|
uint32_t mode; // 0: default, 1: swapped, 2: split
|
||||||
float alpha; // for swiglu_oai
|
float alpha; // for swiglu_oai
|
||||||
float limit;
|
float limit;
|
||||||
|
uint32_t nb01;
|
||||||
|
uint32_t nb02;
|
||||||
|
uint32_t nb03;
|
||||||
|
uint32_t ne01;
|
||||||
|
uint32_t ne02;
|
||||||
|
uint32_t nb11;
|
||||||
|
uint32_t nb12;
|
||||||
|
uint32_t nb13;
|
||||||
|
uint32_t ne11;
|
||||||
|
uint32_t ne12;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct vk_op_unary_push_constants {
|
struct vk_op_unary_push_constants {
|
||||||
|
|
@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
} else {
|
} else {
|
||||||
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
|
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
|
||||||
}
|
}
|
||||||
vk::DeviceCreateInfo device_create_info;
|
vk::DeviceCreateInfo device_create_info{};
|
||||||
std::vector<const char *> device_extensions;
|
std::vector<const char *> device_extensions;
|
||||||
vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
|
vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
|
||||||
|
|
||||||
|
|
@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||||
#endif
|
#endif
|
||||||
device->name = GGML_VK_NAME + std::to_string(idx);
|
device->name = GGML_VK_NAME + std::to_string(idx);
|
||||||
|
|
||||||
device_create_info = {
|
device_create_info
|
||||||
vk::DeviceCreateFlags(),
|
.setFlags(vk::DeviceCreateFlags())
|
||||||
device_queue_create_infos,
|
.setQueueCreateInfos(device_queue_create_infos)
|
||||||
{},
|
.setPEnabledExtensionNames(device_extensions);
|
||||||
device_extensions
|
|
||||||
};
|
|
||||||
device_create_info.setPNext(&device_features2);
|
device_create_info.setPNext(&device_features2);
|
||||||
device->device = device->physical_device.createDevice(device_create_info);
|
device->device = device->physical_device.createDevice(device_create_info);
|
||||||
|
|
||||||
|
|
@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||||
const float alpha = op_params_f[2];
|
const float alpha = op_params_f[2];
|
||||||
const float limit = op_params_f[3];
|
const float limit = op_params_f[3];
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
|
|
||||||
if (!split) {
|
if (!split) {
|
||||||
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
|
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||||
(uint32_t)dst->ne[0],
|
(uint32_t)dst->ne[0],
|
||||||
mode,
|
mode,
|
||||||
alpha,
|
alpha,
|
||||||
limit
|
limit,
|
||||||
|
(uint32_t)(src0->nb[1] / src0->nb[0]),
|
||||||
|
(uint32_t)(src0->nb[2] / src0->nb[0]),
|
||||||
|
(uint32_t)(src0->nb[3] / src0->nb[0]),
|
||||||
|
(uint32_t)src0->ne[1],
|
||||||
|
(uint32_t)src0->ne[2],
|
||||||
|
(uint32_t)(dst->nb[1] / dst->nb[0]),
|
||||||
|
(uint32_t)(dst->nb[2] / dst->nb[0]),
|
||||||
|
(uint32_t)(dst->nb[3] / dst->nb[0]),
|
||||||
|
(uint32_t)dst->ne[1],
|
||||||
|
(uint32_t)dst->ne[2]
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_GLU_OP_SWIGLU_OAI:
|
case GGML_GLU_OP_SWIGLU_OAI:
|
||||||
case GGML_GLU_OP_GEGLU_ERF:
|
case GGML_GLU_OP_GEGLU_ERF:
|
||||||
case GGML_GLU_OP_GEGLU_QUICK:
|
case GGML_GLU_OP_GEGLU_QUICK:
|
||||||
return ggml_is_contiguous(op->src[0]) &&
|
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
||||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||||
(op->src[0]->type == op->type);
|
(op->src[0]->type == op->type);
|
||||||
default:
|
default:
|
||||||
|
|
|
||||||
|
|
@ -16,4 +16,14 @@ layout (push_constant) uniform parameter
|
||||||
uint mode;
|
uint mode;
|
||||||
float alpha;
|
float alpha;
|
||||||
float limit;
|
float limit;
|
||||||
|
uint nb01;
|
||||||
|
uint nb02;
|
||||||
|
uint nb03;
|
||||||
|
uint ne01;
|
||||||
|
uint ne02;
|
||||||
|
uint nb11;
|
||||||
|
uint nb12;
|
||||||
|
uint nb13;
|
||||||
|
uint ne11;
|
||||||
|
uint ne12;
|
||||||
} p;
|
} p;
|
||||||
|
|
|
||||||
|
|
@ -8,22 +8,32 @@ void main() {
|
||||||
const uint row = i / p.ne20;
|
const uint row = i / p.ne20;
|
||||||
const uint col = i - row * p.ne20;
|
const uint col = i - row * p.ne20;
|
||||||
|
|
||||||
|
const uint i3 = row / (p.ne01 * p.ne02);
|
||||||
|
const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01;
|
||||||
|
const uint i1 = row % p.ne01;
|
||||||
|
const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col;
|
||||||
|
|
||||||
|
const uint dst_i3 = row / (p.ne11 * p.ne12);
|
||||||
|
const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11;
|
||||||
|
const uint dst_i1 = row % p.ne11;
|
||||||
|
const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col;
|
||||||
|
|
||||||
if (p.mode == 0) {
|
if (p.mode == 0) {
|
||||||
// Default
|
// Default
|
||||||
const uint offset = p.ne00 / 2;
|
const uint offset = p.ne00 / 2;
|
||||||
const uint idx = row * p.ne00 + col;
|
const uint idx = src_idx;
|
||||||
|
|
||||||
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
|
data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
|
||||||
} else if (p.mode == 1) {
|
} else if (p.mode == 1) {
|
||||||
// Swapped
|
// Swapped
|
||||||
const uint offset = p.ne00 / 2;
|
const uint offset = p.ne00 / 2;
|
||||||
const uint idx = row * p.ne00 + col;
|
const uint idx = src_idx;
|
||||||
|
|
||||||
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
|
data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
|
||||||
} else {
|
} else {
|
||||||
// Split
|
// Split
|
||||||
const uint idx = row * p.ne00 + col;
|
const uint idx = src_idx;
|
||||||
|
|
||||||
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
|
data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue