Add support for CUMSUM and TRI for CUDA. (#17584)

* Add support for CUMSUM and TRI for CUDA.

* Minor optimizations.

* Correct warp_prefix_inclusive_sum in float2 variant to return float2

* Optimize TRI

* Whitespace

* Fix strides.

* Implement double loop

* Whitespace

* Fix HIP compilation bugs

* Optimizations + big case performance tests

* Implement using CUB with fallback to custom kernel

* Remove error message.

* Fixes from code review

* Comment out CPU-unsupported F16/BF16 cases to fix CI

* Fine, you win :P

* Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS

* Vary warp-size based on physical warp size

* Add GGML_UNUSED_VARS in tri as well

* Use constexpr and call prefix_inclusive with warp_size template param

* Update ggml/src/ggml-cuda/cumsum.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Change to tid % warp_size

* Fix strides; hardcode mask; add ggml_lane_mask_t

* Missing renames, remove unused get_warp_mask(), explicit calls to ggml_cuda_info()

* Too hasty...

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Piotr Wilkin (ilintar) 2025-12-04 22:19:51 +01:00 committed by GitHub
parent bde188d60f
commit 96fe9badfc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 448 additions and 0 deletions

View File

@ -463,6 +463,53 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}
template<typename T, int width = WARP_SIZE>
static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
const int lane_id = threadIdx.x % width;
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const T t = __shfl_up_sync(0xffffffff, x, offset, width);
if (lane_id >= offset) {
x += t;
}
}
return x;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
const int lane_id = threadIdx.x % width;
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
if (lane_id >= offset) {
a.x += t_x;
a.y += t_y;
}
}
return a;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
#ifdef FP16_AVAILABLE
const int lane_id = threadIdx.x % width;
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
if (lane_id >= offset) {
a = __hadd2(a, t);
}
}
return a;
#else
NO_DEVICE_CODE;
return a;
#endif // FP16_AVAILABLE
}
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

View File

@ -0,0 +1,237 @@
#include <algorithm>
#include "cumsum.cuh"
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"
#ifdef GGML_CUDA_USE_CUB
# include <cub/device/device_scan.cuh>
#endif // GGML_CUDA_USE_CUB
template<typename T, int BLOCK_SIZE>
static __global__ void cumsum_cub_kernel(
const T * __restrict__ src,
T * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s1, const int64_t s2, const int64_t s3) {
#ifdef GGML_CUDA_USE_CUB
using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
__shared__ typename BlockScan::TempStorage temp_storage;
__shared__ T block_carry; // carry from previous tile
const int tid = threadIdx.x;
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int64_t i3 = blockIdx.z;
if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
return;
}
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
if (tid == 0) {
block_carry = 0;
}
__syncthreads();
for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
int64_t idx = start + tid;
T x = (idx < ne00) ? src_row[idx] : T(0);
T inclusive;
T block_total;
BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total);
__syncthreads();
T final_val = inclusive + block_carry;
// store result
if (idx < ne00) {
dst_row[idx] = final_val;
}
__syncthreads();
if (tid == 0) {
block_carry += block_total;
}
__syncthreads();
}
#else
NO_DEVICE_CODE;
#endif // GGML_CUDA_USE_CUB
}
// Fallback kernel implementation (original)
template<typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {
GGML_UNUSED_VARS(s00, s0);
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int lane = tid % warp_size;
const int warp = tid / warp_size;
const int warps_per_block = blockDim.x / warp_size;
extern __shared__ float smem[];
float * s_vals = smem;
float * s_warp_sums = smem + blockDim.x;
float * s_carry = smem + blockDim.x + warps_per_block;
float * s_chunk_total = s_carry + 1;
// Initialize carry
if (tid == 0) {
*s_carry = 0.0f;
}
__syncthreads();
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
for (int64_t start = 0; start < ne00; start += blockDim.x) {
int64_t idx = start + tid;
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;
// 1. Warp inclusive scan
val = warp_prefix_inclusive_sum<T, warp_size>(val);
s_vals[tid] = val;
// Store warp total
if (lane == warp_size - 1) {
s_warp_sums[warp] = val;
}
__syncthreads();
// 2. Exclusive scan of warp sums (warp 0 only)
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
if (tid < warps_per_block) {
s_warp_sums[tid] = inc - w; // exclusive sum
}
if (tid == warps_per_block - 1) {
*s_chunk_total = inc; // total sum of this chunk
}
}
__syncthreads();
float carry = *s_carry;
float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
if (idx < ne00) {
dst_row[idx] = ggml_cuda_cast<T, float>(final_val);
}
__syncthreads();
// Update carry for next chunk
if (tid == 0) {
*s_carry += *s_chunk_total;
}
__syncthreads();
}
}
template<typename T>
static void cumsum_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
cudaStream_t stream) {
const size_t type_size = sizeof(T);
bool use_cub = false;
#ifdef GGML_CUDA_USE_CUB
// Check if we can use CUB (data must be contiguous along innermost dimension)
const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
if (is_contiguous) {
use_cub = true;
}
#endif // GGML_CUDA_USE_CUB
dim3 grid_dims(ne01, ne02, ne03);
const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
const int warp_size = info.warp_size;
const int num_warps = (ne00 + warp_size - 1) / warp_size;
int block_size = num_warps * warp_size;
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
dim3 block_dims(block_size, 1, 1);
const int warps_per_block = block_size / warp_size;
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
if (use_cub) {
cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
}
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == dst->type);
switch(src0->type) {
case GGML_TYPE_F32:
{
cumsum_cuda(
(const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
// We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms
/*case GGML_TYPE_F16:
{
cumsum_cuda(
(const half *)src0->data, (half *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
case GGML_TYPE_BF16:
{
cumsum_cuda(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;*/
default:
GGML_ABORT("fatal error");
}
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_CUMSUM_BLOCK_SIZE 256
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -54,6 +54,8 @@
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/tri.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml.h"
#include <algorithm>
@ -2701,6 +2703,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
@ -4609,6 +4617,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;

136
ggml/src/ggml-cuda/tri.cu Normal file
View File

@ -0,0 +1,136 @@
#include "common.cuh"
#include "convert.cuh"
#include "tri.cuh"
#include "ggml.h"
template<typename T, bool prefix_keep, int add_to_split>
static __global__ void tri_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const int64_t split_point = i1 + add_to_split;
GGML_UNUSED_VARS(nb00, nb0);
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;
if constexpr (prefix_keep) {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
}
} else {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
}
}
template<typename T>
static void tri_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const ggml_tri_type ttype,
cudaStream_t stream) {
dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
dim3 grid_dims(ne01, ne02, ne03);
const size_t type_size = sizeof(T);
const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
if (prefix_keep) {
if (add_to_split == 0) {
tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else { // only 0 and 1 supported
tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
} else {
if (add_to_split == 0) {
tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
}
}
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();
const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
GGML_ASSERT(src0->type == dst->type);
switch(src0->type) {
case GGML_TYPE_F32:
{
tri_cuda(
(const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
case GGML_TYPE_F16:
{
tri_cuda(
(const half *)src0->data, (half *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
case GGML_TYPE_BF16:
{
tri_cuda(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
default:
GGML_ABORT("fatal error");
}
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_TRI_BLOCK_SIZE 256
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -7725,6 +7725,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
@ -7954,6 +7955,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 }));
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {