vulkan: Use mul_mat_vec_id for small values of n (#18918)
Change ggml_vk_mul_mat_vec_id_q_f16 to loop over the batch dimension and update the indexing calculations in get_offsets. Mat-vec is faster than mat-mat for small values of n. We don't get the same reuse of the weights as in the non-ID path, but with this the cost is linear in n rather than n>1 being far slower than n==1.
This commit is contained in:
parent
ad8d85bd94
commit
50b7f076a5
|
|
@ -991,6 +991,8 @@ struct vk_mat_vec_id_push_constants {
|
|||
uint32_t fusion_flags;
|
||||
uint32_t nei0;
|
||||
uint32_t ne11;
|
||||
uint32_t expert_i1;
|
||||
uint32_t nbi1;
|
||||
};
|
||||
|
||||
struct vk_flash_attn_push_constants {
|
||||
|
|
@ -8083,8 +8085,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
|
||||
const uint64_t nei0 = ids->ne[0];
|
||||
const uint64_t nei1 = ids->ne[1];
|
||||
|
||||
GGML_ASSERT(nei1 == 1);
|
||||
const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
|
||||
|
||||
const uint64_t ne20 = dst->ne[0];
|
||||
const uint64_t ne21 = dst->ne[1];
|
||||
|
|
@ -8168,7 +8169,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
if (quantize_y) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
||||
}
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
|
||||
}
|
||||
|
||||
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
|
||||
|
|
@ -8226,7 +8227,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
uint32_t stride_batch_y = ne10*ne11;
|
||||
|
||||
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
|
||||
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
|
||||
stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
|
||||
}
|
||||
|
||||
const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
|
||||
|
|
@ -8262,23 +8263,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
|||
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
|
||||
}
|
||||
|
||||
// compute
|
||||
const vk_mat_vec_id_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
|
||||
fusion_flags,
|
||||
(uint32_t)nei0, (uint32_t)ne11,
|
||||
};
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{
|
||||
d_X,
|
||||
d_Y,
|
||||
d_D,
|
||||
d_F0,
|
||||
d_F1,
|
||||
d_ids,
|
||||
},
|
||||
pc, { groups_x, (uint32_t)nei0, groups_z });
|
||||
// Loop over the batch dimension
|
||||
for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
|
||||
const vk_mat_vec_id_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
|
||||
fusion_flags,
|
||||
(uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
|
||||
};
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{
|
||||
d_X,
|
||||
d_Y,
|
||||
d_D,
|
||||
d_F0,
|
||||
d_F1,
|
||||
d_ids,
|
||||
},
|
||||
pc, { groups_x, (uint32_t)nei0, groups_z });
|
||||
}
|
||||
|
||||
if (x_non_contig) {
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
|
|
@ -8292,7 +8295,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
|
|||
ggml_tensor * dst = cgraph->nodes[node_idx];
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src2 = dst->src[2];
|
||||
return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
|
||||
return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ layout (push_constant) uniform parameter
|
|||
#ifdef MUL_MAT_ID
|
||||
uint nei0;
|
||||
uint ne11;
|
||||
uint expert_i1;
|
||||
uint nbi1;
|
||||
#else
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
|
|
@ -43,7 +45,7 @@ uint expert_id;
|
|||
|
||||
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
#else
|
||||
const uint batch_idx = gl_GlobalInvocationID.y;
|
||||
#endif
|
||||
|
|
@ -60,7 +62,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
|||
batch_idx_a = i03 * p.ne02 + i02;
|
||||
}
|
||||
#else
|
||||
expert_id = data_ids[expert_idx];
|
||||
expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
|
||||
#endif
|
||||
|
||||
a_offset =
|
||||
|
|
@ -71,13 +73,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
|||
#endif
|
||||
b_offset =
|
||||
#ifdef MUL_MAT_ID
|
||||
(expert_idx % p.ne11) * p.stride_b;
|
||||
(expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
|
||||
#else
|
||||
batch_idx * p.batch_stride_b;
|
||||
#endif
|
||||
d_offset =
|
||||
#ifdef MUL_MAT_ID
|
||||
expert_idx * p.stride_d;
|
||||
expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
|
||||
#else
|
||||
batch_idx * p.batch_stride_d;
|
||||
#endif
|
||||
|
|
@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
|
|||
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
|
|
@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
|||
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
|
|
@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
|||
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue