Fix perf issue on Ampere. Use mmvf mm-id only for non-NVIDIA GPUs

This commit is contained in:
Aman Gupta 2026-01-23 16:47:49 +01:00
parent dff9128825
commit 378ab5fd90
2 changed files with 18 additions and 8 deletions

View File

@ -2282,11 +2282,16 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
if (ne2 <= 4) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
} else {
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
if (GGML_CUDA_CC_IS_AMD(cc)) {
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
return;
}
}
return;
}
if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {

View File

@ -165,10 +165,15 @@ static __global__ void mul_mat_vec_q(
// for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
const uint32_t channel_dst = blockIdx.y;
const uint32_t token_idx = ids ? blockIdx.z : 0;
const uint32_t channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio);
const uint32_t channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
const uint32_t sample_dst = ids ? 0 : blockIdx.z;
const uint32_t token_idx = blockIdx.z;
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio);
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
uint32_t sample_dst = blockIdx.z;
if constexpr (ncols_dst == 1) {
sample_dst *= !ids_stride; // sample_dst for ids is 0
}
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
@ -190,11 +195,11 @@ static __global__ void mul_mat_vec_q(
active_glu = fusion.glu_op;
}
const uint32_t channel_bias = ids ? channel_x : channel_dst;
float x_biases[ncols_dst] = { 0.0f };
float gate_biases[ncols_dst] = { 0.0f };
if constexpr (has_fusion) {
const uint32_t channel_bias = ids ? channel_x : channel_dst;
if (use_bias) {
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
// 1. Hide latency by prefetching bias and gate here