CUDA: use shared mem for ssm_conv (#20128)

* CUDA: use shared mem for ssm_conv

* fuse silu + ssm_conv

* fuse unary + mul

* enable for fp16

* formatting

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Aman Gupta 2026-03-06 23:09:59 +08:00 committed by GitHub
parent 388baabc06
commit 1e38a7a6fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 168 additions and 33 deletions

View File

@@ -3348,6 +3348,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
return true;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
const ggml_tensor * silu = cgraph->nodes[node_idx+1];
if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
return false;
}
return true;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
&& unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
const ggml_tensor * unary = cgraph->nodes[node_idx];
const ggml_tensor * mul = cgraph->nodes[node_idx+1];
if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
return false;
}
if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
return false;
}
if (unary->type != mul->type) {
return false;
}
const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
if (other->type != unary->type) {
return false;
}
if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {
return false;
}
return true;
}
if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
const ggml_tensor *scale = cgraph->nodes[node_idx];
@@ -3836,6 +3876,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
i += 2;
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);

View File

@@ -1,6 +1,7 @@
#include "ssm-conv.cuh"
#include "unary.cuh"
template <size_t split_d_inner, size_t d_conv>
template <bool apply_silu, size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
@@ -41,11 +42,11 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
}
}
template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
@@ -65,36 +66,46 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);
float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };
const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t);
const int n_cols = d_conv - 1 + split_n_t;
extern __shared__ float smem[];
constexpr int load_cols = d_conv - 1 + split_n_t;
constexpr int total_elems = split_d_inner * load_cols;
int row = tid / load_cols;
int col = tid % load_cols;
#pragma unroll
for (int idx = tid; idx < total_elems; idx += split_d_inner) {
if (row < (int)split_d_inner) {
smem[row * n_cols + col] = x_block[row * stride_x + col];
}
col += split_d_inner;
row += col / load_cols;
col = col % load_cols;
}
__syncthreads();
// Load weights into registers (done once, small)
float w[d_conv] = { 0.0f };
#pragma unroll
for (size_t j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}
// Compute from shared memory
for (int64_t i = 0; i < local_n_t; i++) {
float sumf = 0.0f;
#pragma unroll
for (int64_t i = 0; i < split_n_t; i++) {
if (bidz * split_n_t + i < n_t) {
float sumf = 0.0f;
if (i == 0) {
for (size_t j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}
#pragma unroll
for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
for (size_t j = 0; j < d_conv; j++) {
sumf += smem[tid * n_cols + i + j] * w[j];
}
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
}
}
template <bool apply_silu>
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
@@ -106,12 +117,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
constexpr int kNC = decltype(NC)::value;
if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
}
};
@@ -124,27 +136,36 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
}
}
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const bool fuse_silu = silu_dst != nullptr;
// When fusing, write to silu_dst (the node downstream references).
const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
const int64_t nc = src1->ne[0]; // d_conv
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = dst->ne[1]; // tokens per sequence
const int64_t n_s = dst->ne[2]; // number of sequences in the batch
const int64_t n_t = out->ne[1]; // tokens per sequence
const int64_t n_s = out->ne[2]; // number of sequences in the batch
GGML_ASSERT(dst->ne[0] == nr);
GGML_ASSERT(out->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
float * dst_d = (float *) out->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
dst->nb[2], nc, nr, n_t, n_s, stream);
GGML_ASSERT(out->type == GGML_TYPE_F32);
if (fuse_silu) {
ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
out->nb[2], nc, nr, n_t, n_s, stream);
} else {
ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
out->nb[2], nc, nr, n_t, n_s, stream);
}
}

View File

@@ -1,3 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);

View File

@@ -560,3 +560,58 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
}
}
/* fused unary + mul */
template <float (*op)(float)>
static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
    // Fused elementwise kernel launch: dst = op(unary_src) * other_src.
    // unary_node is the UNARY graph node applied to unary_node->src[0];
    // mul_node is MUL(a, b) where one of a/b is unary_node. Output is written
    // to mul_node->data (the tensor downstream consumers reference).
    const ggml_tensor * unary_src = unary_node->src[0];

    // The gate operand is whichever MUL input is not the unary node itself.
    const ggml_tensor * other_src;
    if (mul_node->src[0] == unary_node) {
        other_src = mul_node->src[1];
    } else {
        other_src = mul_node->src[0];
    }

    // Preconditions enforced by ggml_cuda_can_fuse; re-checked here defensively.
    GGML_ASSERT(ggml_is_contiguous_1(unary_src));
    GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src));
    GGML_ASSERT(ggml_is_contiguous_1(other_src));
    GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src));
    GGML_ASSERT(ggml_are_same_shape(unary_src, other_src));
    GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16);
    GGML_ASSERT(unary_src->type == other_src->type);
    GGML_ASSERT(unary_src->type == mul_node->type);

    const int64_t n_elems   = ggml_nelements(mul_node);  // total elements to produce
    const int64_t row_len   = unary_src->ne[0];          // contiguous row length
    const int64_t nb1_unary = unary_src->nb[1];          // row stride in bytes
    const int64_t nb1_other = other_src->nb[1];          // row stride in bytes

    cudaStream_t stream = ctx.stream();

    // Element strides are passed to the kernel in units of the element type.
    if (unary_src->type == GGML_TYPE_F16) {
        unary_gated_cuda<op>((const half *) unary_src->data, (const half *) other_src->data,
                             (half *) mul_node->data, n_elems, row_len,
                             nb1_unary / sizeof(half), nb1_other / sizeof(half), stream);
    } else {
        unary_gated_cuda<op>((const float *) unary_src->data, (const float *) other_src->data,
                             (float *) mul_node->data, n_elems, row_len,
                             nb1_unary / sizeof(float), nb1_other / sizeof(float), stream);
    }
}
// Entry point for the fused UNARY + MUL path: dispatch on the concrete unary
// op and instantiate the matching kernel template. Aborts on any op that the
// fusion check should already have rejected.
void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
    const ggml_unary_op uop = ggml_get_unary_op(unary_node);
    if (uop == GGML_UNARY_OP_SILU) {
        ggml_cuda_op_unary_mul_impl<op_silu>(ctx, unary_node, mul_node);
    } else if (uop == GGML_UNARY_OP_SIGMOID) {
        ggml_cuda_op_unary_mul_impl<op_sigmoid>(ctx, unary_node, mul_node);
    } else if (uop == GGML_UNARY_OP_SOFTPLUS) {
        ggml_cuda_op_unary_mul_impl<op_softplus>(ctx, unary_node, mul_node);
    } else {
        GGML_ABORT("Unsupported unary op for fused unary+mul");
    }
}

View File

@@ -89,6 +89,8 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);
// SiLU activation for a single element: x * sigmoid(x), evaluated as
// x / (1 + exp(-x)) — the same operation order as before, so results are
// bit-identical.
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
    const float denom = 1.0f + expf(-x);
    return x / denom;
}

View File

@@ -7663,6 +7663,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
// long token (n_t > 32, exercises the long_token kernel path)
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
}
}