CUDA: Optimize argsort for gpu-based token sampling

Argsort is used for top-k currently. WE optimize argsort by 2 things:

1. Use `DeviceRadixSort` for single-row/sequence to parallelize it
   across our SMs
2. Use `DeviceSegmentedSort` for multi-row/sequence as this is the
   correct entrypoint (the function chooses different execution paths,
   it contains `DeviceSegmentedRadixSort` as one of the paths and will
   choose the best one according to heuristics.
   https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#overview

Some perf numbers for a RTX PRO 6000:

On the kernel level, tested with
`GGML_CUDA_DISABLE_GRAPHS=1 ./test-backend-ops -o ARGSORT perf`
Before:
```
  ARGSORT(type=f32,ne=[65000,16,1,1],order=0):                  4130 runs -   359.24 us/run
  ARGSORT(type=f32,ne=[200000,1,1,1],order=0):                  8192 runs -   861.34 us/run
  ARGSORT(type=f32,ne=[200000,16,1,1],order=0):                 1343 runs -  1020.01 us/run
```

After:
```
  ARGSORT(type=f32,ne=[65000,16,1,1],order=0):                  4130 runs -   312.41 us/run
  ARGSORT(type=f32,ne=[200000,1,1,1],order=0):                 16384 runs -    63.48 us/run
  ARGSORT(type=f32,ne=[200000,16,1,1],order=0):                 1343 runs -   874.36 us/run
```

---
On the model level, tested with
`llama-cli -m gpt-oss-20b-mxfp4.gguf -n 200 -p "What is
the Capital of Sweden?" -no-cnv -fa 1 --backend-sampling`

Before:
```
llama_perf_sampler_print:    sampling time =       0.25 ms /   207 runs   (    0.00 ms per token, 824701.20 tokens per second)
llama_perf_context_print:        load time =   18215.58 ms
llama_perf_context_print: prompt eval time =      28.20 ms /     7 tokens (    4.03 ms per token,   248.19 tokens per second)
llama_perf_context_print:        eval time =     714.79 ms /   199 runs   (    3.59 ms per token,   278.40 tokens per second)
llama_perf_context_print:       total time =     857.62 ms /   206 tokens
```

After
```
llama_perf_sampler_print:    sampling time =       0.25 ms /   207 runs   (    0.00 ms per token, 828000.00 tokens per second)
llama_perf_context_print:        load time =   18366.92 ms
llama_perf_context_print: prompt eval time =      35.92 ms /     7 tokens (    5.13 ms per token,   194.87 tokens per second)
llama_perf_context_print:        eval time =     532.79 ms /   199 runs   (    2.68 ms per token,   373.50 tokens per second)
llama_perf_context_print:       total time =     683.65 ms /   206 tokens
```
This commit is contained in:
Oliver Simons 2025-11-18 18:17:44 +01:00
parent 311c1a347f
commit 26be108be8
2 changed files with 37 additions and 14 deletions

View File

@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
size_t temp_storage_bytes = 0;
if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
stream);
if (nrows == 1) {
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, stream);
}
} else {
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
sizeof(float) * 8, stream);
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
}
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
if (order == GGML_SORT_ORDER_ASC) {
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
stream);
if (nrows == 1) {
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
} else {
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
0, sizeof(float) * 8, stream);
if (nrows == 1) {
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
stream);
}
}
}
#endif // GGML_CUDA_USE_CUB

View File

@ -7886,6 +7886,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));
return test_cases;
}