diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ded400216f..a635668c51 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2282,11 +2282,16 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + if (ne2 <= 4) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + return; + } } else { - ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + if (GGML_CUDA_CC_IS_AMD(cc)) { + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + return; + } } - return; } if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 8d5c4dcd31..b0c03eab3e 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -165,10 +165,15 @@ static __global__ void mul_mat_vec_q( // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const uint32_t channel_dst = blockIdx.y; - const uint32_t token_idx = ids ? blockIdx.z : 0; - const uint32_t channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio); - const uint32_t channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst; - const uint32_t sample_dst = ids ? 0 : blockIdx.z; + const uint32_t token_idx = blockIdx.z; + const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio); + const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + uint32_t sample_dst = blockIdx.z; + + if constexpr (ncols_dst == 1) { + sample_dst *= !ids_stride; // sample_dst for ids is 0 + } + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -190,11 +195,11 @@ static __global__ void mul_mat_vec_q( active_glu = fusion.glu_op; } - const uint32_t channel_bias = ids ? channel_x : channel_dst; float x_biases[ncols_dst] = { 0.0f }; float gate_biases[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { + const uint32_t channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; // 1. Hide latency by prefetching bias and gate here