Merge feb3eef27e into 05fa625eac
This commit is contained in:
commit
28daa074ea
|
|
@ -944,6 +944,7 @@ struct vk_mat_mat_push_constants {
|
|||
uint32_t M; uint32_t N; uint32_t K;
|
||||
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
|
||||
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
|
||||
uint32_t base_work_group_z; uint32_t num_batches;
|
||||
uint32_t k_split;
|
||||
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
|
||||
uint32_t padded_N;
|
||||
|
|
@ -963,6 +964,7 @@ struct vk_mat_vec_push_constants {
|
|||
uint32_t batch_stride_b;
|
||||
uint32_t batch_stride_d;
|
||||
uint32_t fusion_flags;
|
||||
uint32_t base_work_group_y;
|
||||
uint32_t ne02;
|
||||
uint32_t ne12;
|
||||
uint32_t broadcast2;
|
||||
|
|
@ -6773,8 +6775,16 @@ static void ggml_vk_matmul(
|
|||
uint32_t padded_n) {
|
||||
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
|
||||
if (split_k == 1) {
|
||||
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
|
||||
|
||||
uint32_t base_work_group_z = 0;
|
||||
while (base_work_group_z < batch) {
|
||||
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
|
||||
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
|
||||
base_work_group_z += groups_z;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -6788,9 +6798,17 @@ static void ggml_vk_matmul(
|
|||
uint32_t k_split = CEIL_DIV(k, split_k);
|
||||
k_split = ROUNDUP_POW2(k_split, 256);
|
||||
|
||||
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
// Make sure enough workgroups get assigned for split k to work
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
|
||||
|
||||
uint32_t base_work_group_z = 0;
|
||||
while (base_work_group_z < batch) {
|
||||
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
|
||||
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
// Make sure enough workgroups get assigned for split k to work
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
|
||||
base_work_group_z += groups_z;
|
||||
}
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
|
||||
|
|
@ -7186,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
}
|
||||
|
||||
// Request descriptor sets
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
if (qx_needs_dequant) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
|
||||
}
|
||||
|
|
@ -7484,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
if (quantize_y) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
||||
}
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
|
||||
}
|
||||
|
||||
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
|
||||
|
|
@ -7579,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
|
||||
}
|
||||
|
||||
// compute
|
||||
const vk_mat_vec_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
stride_batch_x, stride_batch_y, stride_batch_d,
|
||||
fusion_flags,
|
||||
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
||||
};
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{
|
||||
d_X,
|
||||
d_Y,
|
||||
d_D,
|
||||
d_F0,
|
||||
d_F1,
|
||||
},
|
||||
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
|
||||
|
||||
uint32_t base_work_group_y = 0;
|
||||
while (base_work_group_y < ne12 * ne13) {
|
||||
|
||||
uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
const vk_mat_vec_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
stride_batch_x, stride_batch_y, stride_batch_d,
|
||||
fusion_flags, base_work_group_y,
|
||||
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
||||
};
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{
|
||||
d_X,
|
||||
d_Y,
|
||||
d_D,
|
||||
d_F0,
|
||||
d_F1,
|
||||
},
|
||||
pc, { groups_x, groups_y, groups_z });
|
||||
base_work_group_y += groups_y;
|
||||
}
|
||||
|
||||
if (x_non_contig) {
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
|
|
@ -7832,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|||
src1->nb[2] <= src1->nb[1] &&
|
||||
src1->nb[1] <= src1->nb[3] &&
|
||||
src0->ne[3] == 1 &&
|
||||
src1->ne[3] == 1) {
|
||||
src1->ne[3] == 1 &&
|
||||
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
||||
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
|
||||
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
|
||||
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
||||
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
|
||||
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
|
||||
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
||||
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
|
||||
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
|
||||
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
||||
// when ne12 and ne13 are one.
|
||||
|
|
@ -11560,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|||
}
|
||||
}
|
||||
|
||||
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
|
||||
if (split_k > 1) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
||||
|
||||
|
|
@ -12069,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|||
// y[i] = i % k;
|
||||
}
|
||||
|
||||
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
|
||||
if (split_k > 1) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ layout (push_constant) uniform parameter
|
|||
uint expert_i1;
|
||||
uint nbi1;
|
||||
#else
|
||||
uint base_work_group_y;
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
uint broadcast2;
|
||||
|
|
@ -45,9 +46,9 @@ uint expert_id;
|
|||
|
||||
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
const uint expert_i0 = gl_WorkGroupID.y;
|
||||
#else
|
||||
const uint batch_idx = gl_GlobalInvocationID.y;
|
||||
const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
|
||||
#endif
|
||||
|
||||
#ifndef MUL_MAT_ID
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@ layout (push_constant) uniform parameter
|
|||
uint nbi1;
|
||||
uint ne11;
|
||||
#else
|
||||
uint base_work_group_z;
|
||||
uint num_batches;
|
||||
uint k_split;
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
|
|
@ -139,7 +141,7 @@ void main() {
|
|||
const uint ic = gl_WorkGroupID.y;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_idx = gl_GlobalInvocationID.z;
|
||||
const uint expert_idx = gl_WorkGroupID.z;
|
||||
if (ic * BN >= data_expert_count[expert_idx]) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -149,7 +151,7 @@ void main() {
|
|||
#endif
|
||||
|
||||
#ifndef MUL_MAT_ID
|
||||
const uint batch_idx = gl_GlobalInvocationID.z;
|
||||
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
|
||||
|
||||
const uint i13 = batch_idx / p.ne12;
|
||||
const uint i12 = batch_idx % p.ne12;
|
||||
|
|
@ -366,7 +368,7 @@ void main() {
|
|||
const uint dc = ic * BN + warp_c * WN;
|
||||
|
||||
#ifndef MUL_MAT_ID
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
|
|
|
|||
|
|
@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
|
|||
uint nbi1;
|
||||
uint ne11;
|
||||
#else
|
||||
uint base_work_group_z;
|
||||
uint num_batches;
|
||||
uint k_split;
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
|
|
@ -197,7 +199,7 @@ void main() {
|
|||
const uint ic = gl_WorkGroupID.y;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_idx = gl_GlobalInvocationID.z;
|
||||
const uint expert_idx = gl_WorkGroupID.z;
|
||||
if (ic * BN >= data_expert_count[expert_idx]) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -215,7 +217,7 @@ void main() {
|
|||
#endif
|
||||
|
||||
#ifndef MUL_MAT_ID
|
||||
const uint batch_idx = gl_GlobalInvocationID.z;
|
||||
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
|
||||
|
||||
const uint i13 = batch_idx / p.ne12;
|
||||
const uint i12 = batch_idx % p.ne12;
|
||||
|
|
@ -255,7 +257,7 @@ void main() {
|
|||
#else
|
||||
uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
|
||||
uint pos_b = batch_idx * p.batch_stride_b;
|
||||
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
|
||||
#endif
|
||||
|
||||
uint stride_a = p.stride_a / QUANT_K;
|
||||
|
|
|
|||
Loading…
Reference in New Issue