ggml : check return value of CUB calls used in argsort and top-k (they all return cudaError_t) (#21676)
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
parent
c8ac02fa1b
commit
009a113326
|
|
@ -60,24 +60,24 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||||
|
|
||||||
if (order == GGML_SORT_ORDER_ASC) {
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
if (nrows == 1) {
|
if (nrows == 1) {
|
||||||
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||||
temp_indices, dst, // values (indices)
|
temp_indices, dst, // values (indices)
|
||||||
ncols, 0, sizeof(float) * 8, stream);
|
ncols, 0, sizeof(float) * 8, stream));
|
||||||
} else {
|
} else {
|
||||||
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||||
temp_indices, dst, // values (indices)
|
temp_indices, dst, // values (indices)
|
||||||
ncols * nrows, nrows, // num items, num segments
|
ncols * nrows, nrows, // num items, num segments
|
||||||
offset_iterator, offset_iterator + 1, stream);
|
offset_iterator, offset_iterator + 1, stream));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (nrows == 1) {
|
if (nrows == 1) {
|
||||||
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||||
temp_indices, dst, // values (indices)
|
temp_indices, dst, // values (indices)
|
||||||
ncols, 0, sizeof(float) * 8, stream);
|
ncols, 0, sizeof(float) * 8, stream));
|
||||||
} else {
|
} else {
|
||||||
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||||
stream);
|
stream));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -86,22 +86,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||||
|
|
||||||
if (order == GGML_SORT_ORDER_ASC) {
|
if (order == GGML_SORT_ORDER_ASC) {
|
||||||
if (nrows == 1) {
|
if (nrows == 1) {
|
||||||
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||||
temp_indices, dst, // values (indices)
|
temp_indices, dst, // values (indices)
|
||||||
ncols, 0, sizeof(float) * 8, stream);
|
ncols, 0, sizeof(float) * 8, stream));
|
||||||
} else {
|
} else {
|
||||||
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
|
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (nrows == 1) {
|
if (nrows == 1) {
|
||||||
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||||
temp_indices, dst, // values (indices)
|
temp_indices, dst, // values (indices)
|
||||||
ncols, 0, sizeof(float) * 8, stream);
|
ncols, 0, sizeof(float) * 8, stream));
|
||||||
} else {
|
} else {
|
||||||
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||||
offset_iterator + 1, stream);
|
offset_iterator + 1, stream));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,14 +25,14 @@ static void top_k_cub(ggml_cuda_pool & pool,
|
||||||
auto indexes_in = cuda::make_counting_iterator(0);
|
auto indexes_in = cuda::make_counting_iterator(0);
|
||||||
|
|
||||||
size_t temp_storage_bytes = 0;
|
size_t temp_storage_bytes = 0;
|
||||||
DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
|
CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
|
||||||
env);
|
env));
|
||||||
|
|
||||||
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
||||||
void * d_temp_storage = temp_storage_alloc.get();
|
void * d_temp_storage = temp_storage_alloc.get();
|
||||||
|
|
||||||
DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
|
CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
|
||||||
ncols, k, env);
|
ncols, k, env));
|
||||||
}
|
}
|
||||||
|
|
||||||
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
|
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue