Fix perf issue on ampere. Use mmvf mm-id only for non-nvidia GPUs
This commit is contained in:
parent
dff9128825
commit
378ab5fd90
|
|
@ -2282,11 +2282,16 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|||
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
|
||||
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
|
||||
if (ggml_is_quantized(src0->type)) {
|
||||
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
|
||||
if (ne2 <= 4) {
|
||||
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
|
||||
|
|
|
|||
|
|
@ -165,10 +165,15 @@ static __global__ void mul_mat_vec_q(
|
|||
|
||||
// for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
|
||||
const uint32_t channel_dst = blockIdx.y;
|
||||
const uint32_t token_idx = ids ? blockIdx.z : 0;
|
||||
const uint32_t channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio);
|
||||
const uint32_t channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
|
||||
const uint32_t sample_dst = ids ? 0 : blockIdx.z;
|
||||
const uint32_t token_idx = blockIdx.z;
|
||||
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio);
|
||||
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
|
||||
uint32_t sample_dst = blockIdx.z;
|
||||
|
||||
if constexpr (ncols_dst == 1) {
|
||||
sample_dst *= !ids_stride; // sample_dst for ids is 0
|
||||
}
|
||||
|
||||
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
|
||||
const uint32_t sample_y = sample_dst;
|
||||
|
||||
|
|
@ -190,11 +195,11 @@ static __global__ void mul_mat_vec_q(
|
|||
active_glu = fusion.glu_op;
|
||||
}
|
||||
|
||||
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
||||
|
||||
float x_biases[ncols_dst] = { 0.0f };
|
||||
float gate_biases[ncols_dst] = { 0.0f };
|
||||
if constexpr (has_fusion) {
|
||||
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
||||
if (use_bias) {
|
||||
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||
// 1. Hide latency by prefetching bias and gate here
|
||||
|
|
|
|||
Loading…
Reference in New Issue