CUDA: Fix loop unrolling for BW in mul_mat_q_stream_k_fixup (#19053)
By providing stride_* variables as size_t (i.e., 64-bit) the compiler can
correctly unroll the [two for-loops](557515be1e/ggml/src/ggml-cuda/mmq.cuh (L3789-L3816))
on BW. This gives some perf for prefill/pp phase on BW, while not affecting
other SMs:
| GPU | Model | Test | t/s master | t/s osimons/fix_bw_mmq_fixup_kernel | Speedup |
|:--------------------------------------------------------|:----------------------|:-------|-------------:|--------------------------------------:|----------:|
| NVIDIA RTX 6000 Ada Generation | gpt-oss 20B MXFP4 MoE | pp8096 | 8404.05 | 8375.79 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | llama 3B Q4_K_M | pp8096 | 16148.93 | 16019.60 | 0.99 |
| NVIDIA RTX 6000 Ada Generation | llama 8B Q4_0 | pp8096 | 8008.29 | 7978.80 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | nemotron_h 9B BF16 | pp8096 | 4263.16 | 4248.53 | 1.00 |
| NVIDIA RTX 6000 Ada Generation | nemotron_h 9B Q4_K_M | pp8096 | 5165.11 | 5157.43 | 1.00 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | gpt-oss 20B MXFP4 MoE | pp8096 | 12582.80 | 12758.37 | 1.01 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | llama 3B Q4_K_M | pp8096 | 16879.10 | 17619.47 | 1.04 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | llama 8B Q4_0 | pp8096 | 10649.90 | 10982.65 | 1.03 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | nemotron_h 9B BF16 | pp8096 | 7717.73 | 7716.22 | 1.00 |
| NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | nemotron_h 9B Q4_K_M | pp8096 | 7301.90 | 7370.38 | 1.01 |
This commit is contained in:
parent
e9a859db3c
commit
1f1e57f2bf
|
|
@ -3697,13 +3697,20 @@ static __global__ void mul_mat_q(
|
|||
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
|
||||
}
|
||||
|
||||
|
||||
template <ggml_type type, int mmq_x, bool need_check>
|
||||
static __global__ void mul_mat_q_stream_k_fixup(
|
||||
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
|
||||
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
|
||||
const int32_t * expert_bounds,
|
||||
float * __restrict__ dst,
|
||||
const float * __restrict__ tmp_last_tile,
|
||||
const int ncols_x,
|
||||
const int nrows_x,
|
||||
const int ncols_dst,
|
||||
const size_t stride_col_dst,
|
||||
const int nchannels_y,
|
||||
const size_t stride_channel_dst,
|
||||
const int nsamples_y,
|
||||
const size_t stride_sample_dst,
|
||||
const int ncols_max) {
|
||||
constexpr int mmq_y = get_mmq_y_device();
|
||||
constexpr int qk = ggml_cuda_type_traits<type>::qk;
|
||||
constexpr int ITER_K = get_iter_k(type);
|
||||
|
|
|
|||
Loading…
Reference in New Issue