diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index b2dcaf42fc..35015bc7f3 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3348,6 +3348,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
         return true;
     }
 
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
+     && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
+        const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; // first node of the candidate SSM_CONV -> SILU pair
+        const ggml_tensor * silu = cgraph->nodes[node_idx+1]; // node immediately following it in the graph
+
+        if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { // fusion is only accepted when both nodes are F32
+            return false;
+        }
+
+        return true;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
+     && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
+        const ggml_tensor * unary = cgraph->nodes[node_idx]; // first node of the candidate UNARY -> MUL pair
+        const ggml_tensor * mul = cgraph->nodes[node_idx+1]; // node immediately following it in the graph
+
+        if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) { // the node's actual unary op must be the one the caller asked about
+            return false;
+        }
+
+        if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { // only F32 and F16 are accepted
+            return false;
+        }
+
+        if (unary->type != mul->type) { // no type conversion between the unary result and the MUL output
+            return false;
+        }
+
+        const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0]; // the MUL operand that is not the unary node
+        if (other->type != unary->type) {
+            return false;
+        }
+        if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) { // identical shapes required, so broadcasting MULs are rejected
+            return false;
+        }
+
+        return true;
+    }
+
     if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
      && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
         const ggml_tensor *scale = cgraph->nodes[node_idx];
@@ -3836,6 +3876,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                     continue;
                 }
 
+                if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
+                    ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]); // the following SILU node is passed as the fused output tensor
+                    i++; // skip the SILU node that the fused call already handled
+                    continue;
+                }
+
+                if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
+                    ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
+                    ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
+                    ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]); // node is the unary op, nodes[i+1] the MUL that receives the result
+                    i++; // skip the MUL node that the fused call already handled
+                    continue;
+                }
+
                 if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
                     i += 2;
                     ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu
index 6d5ea704c6..85e82b5a42 100644
--- a/ggml/src/ggml-cuda/ssm-conv.cu
+++ b/ggml/src/ggml-cuda/ssm-conv.cu
@@ -1,6 +1,7 @@
 #include "ssm-conv.cuh"
+#include "unary.cuh"
 
-template
+template
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
@@ -41,11 +42,11 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
         for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
-        y_block[i * stride_y + tid] = sumf;
+        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; // SiLU is applied to the conv result only when apply_silu is set
     }
 }
 
-template
+template
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
@@ -65,36 +66,46 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     const int stride_w = src1_nb1 / sizeof(float);
     const int stride_y = dst_nb1 / sizeof(float);
 
-    float x[d_conv] = { 0.0f };
-    float w[d_conv] = { 0.0f };
+    const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t); // tokens this block writes; below split_n_t when fewer tokens remain
+    const int n_cols = d_conv - 1 + split_n_t; // row pitch (in floats) of the staged tile
 
+    extern __shared__ float smem[]; // dynamic shared memory; presumably sized by smem_size in ssm_conv_f32_cuda — TODO confirm
+
+    constexpr int load_cols = d_conv - 1 + split_n_t; // same value as n_cols, as a compile-time constant
+    constexpr int total_elems = split_d_inner * load_cols; // number of floats in the staged tile
+    int row = tid / load_cols; // tile coordinates of the first element this thread loads
+    int col = tid % load_cols;
+#pragma unroll
+    for (int idx = tid; idx < total_elems; idx += split_d_inner) { // strides by split_d_inner; assumes blockDim.x == split_d_inner — TODO confirm
+        if (row < (int)split_d_inner) {
+            smem[row * n_cols + col] = x_block[row * stride_x + col]; // NOTE(review): all load_cols columns are read even when local_n_t < split_n_t — confirm the final chunk cannot read past the end of the src0 row/buffer
+        }
+
+        col += split_d_inner; // advance (row, col) by one stride, wrapping col into the next row(s)
+        row += col / load_cols;
+        col = col % load_cols;
+    }
+    __syncthreads(); // barrier between the cooperative smem writes above and the per-thread reads below
+
+    // Load weights into registers (done once, small)
+    float w[d_conv] = { 0.0f };
 #pragma unroll
     for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }
 
+    // Compute from shared memory
+    for (int64_t i = 0; i < local_n_t; i++) {
+        float sumf = 0.0f;
 #pragma unroll
-    for (int64_t i = 0; i < split_n_t; i++) {
-        if (bidz * split_n_t + i < n_t) {
-            float sumf = 0.0f;
-
-            if (i == 0) {
-                for (size_t j = 0; j < d_conv; j++) {
-                    x[j] = x_block[tid * stride_x + j];
-                }
-            } else {
-                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
-            }
-
-#pragma unroll
-            for (size_t j = 0; j < d_conv; j++) {
-                sumf += x[(i + j) % d_conv] * w[j];
-            }
-            y_block[i * stride_y + tid] = sumf;
+        for (size_t j = 0; j < d_conv; j++) {
+            sumf += smem[tid * n_cols + i + j] * w[j]; // tile row tid, d_conv-wide window starting at column i
         }
+        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; // SiLU is applied to the conv result only when apply_silu is set
     }
 }
 
+template
 static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                               const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
@@ -106,12 +117,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
         constexpr int kNC = decltype(NC)::value;
         if (n_t <= 32) {
             const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-            ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+            ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                        dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             const int64_t split_n_t = 32;
             dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<<>>(
+            const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); // (kNC - 1 + split_n_t) floats per thread; presumably matches n_cols in the kernel — TODO confirm
+            ssm_conv_long_token_f32<<>>(
                 src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         }
     };
@@ -124,27 +136,36 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     }
 }
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { // a non-null silu_dst requests the fused conv + SiLU path
     const struct ggml_tensor * src0 = dst->src[0]; // conv_x
     const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+    const bool fuse_silu = silu_dst != nullptr;
+
+    // When fusing, write to silu_dst (the node downstream references).
+    const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
 
     const int64_t nc = src1->ne[0]; // d_conv
     const int64_t nr = src0->ne[1]; // d_inner
-    const int64_t n_t = dst->ne[1]; // tokens per sequence
-    const int64_t n_s = dst->ne[2]; // number of sequences in the batch
+    const int64_t n_t = out->ne[1]; // tokens per sequence
+    const int64_t n_s = out->ne[2]; // number of sequences in the batch
 
-    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(out->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
-    float * dst_d = (float *) dst->data;
+    float * dst_d = (float *) out->data; // results go to whichever tensor was selected as out
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
-                      dst->nb[2], nc, nr, n_t, n_s, stream);
+    GGML_ASSERT(out->type == GGML_TYPE_F32);
+    if (fuse_silu) { // presumably selects the kernel variant that applies SiLU — TODO confirm against the template arguments
+        ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+                          out->nb[2], nc, nr, n_t, n_s, stream);
+    } else {
+        ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+                          out->nb[2], nc, nr, n_t, n_s, stream);
+    }
 }
diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh
index 8e6c1f00bf..f96a1cd248 100644
--- a/ggml/src/ggml-cuda/ssm-conv.cuh
+++ b/ggml/src/ggml-cuda/ssm-conv.cuh
@@ -1,3 +1,3 @@
 #include "common.cuh"
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr); // the nullptr default keeps existing two-argument callers compiling
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index d4866067a4..4ad30fa1f3 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -560,3 +560,58 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
     }
 }
+
+/* fused unary + mul */
+
+template
+static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+    // unary_node: UNARY op applied to unary_node->src[0]
+    // mul_node: MUL(a, b) where one of a/b is unary_node
+    // Output goes to mul_node->data
+
+    const ggml_tensor * unary_src = unary_node->src[0]; // input to the unary op
+    const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0]; // the MUL operand that is not unary_node
+
+    GGML_ASSERT(ggml_is_contiguous_1(unary_src));
+    GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src)); // elements within a row must be densely packed
+    GGML_ASSERT(ggml_is_contiguous_1(other_src));
+    GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src)); // elements within a row must be densely packed
+    GGML_ASSERT(ggml_are_same_shape(unary_src, other_src)); // both inputs must have identical shapes (no broadcasting)
+
+    GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16);
+    GGML_ASSERT(unary_src->type == other_src->type);
+    GGML_ASSERT(unary_src->type == mul_node->type);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t k = ggml_nelements(mul_node); // total number of output elements
+    const int64_t nc = unary_src->ne[0]; // size of dimension 0 (row length)
+    const int64_t unary_stride = unary_src->nb[1]; // nb[1] in bytes; divided by the element size at the call below
+    const int64_t other_stride = other_src->nb[1]; // nb[1] in bytes; divided by the element size at the call below
+
+    if (unary_src->type == GGML_TYPE_F16) {
+        unary_gated_cuda((const half *) unary_src->data, (const half *) other_src->data,
+                             (half *) mul_node->data, k, nc,
+                             unary_stride / sizeof(half), other_stride / sizeof(half), stream);
+    } else {
+        unary_gated_cuda((const float *) unary_src->data, (const float *) other_src->data,
+                             (float *) mul_node->data, k, nc,
+                             unary_stride / sizeof(float), other_stride / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+    switch (ggml_get_unary_op(unary_node)) { // dispatch on the unary op stored in the node
+        case GGML_UNARY_OP_SILU:
+            ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node);
+            break;
+        case GGML_UNARY_OP_SIGMOID:
+            ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node);
+            break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node);
+            break;
+        default:
+            GGML_ABORT("Unsupported unary op for fused unary+mul"); // any op other than the three cases above aborts
+    }
+}
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 609046e569..f1dd2183a6 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -89,6 +89,8 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
 void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node); // fused unary + MUL; the result is written to mul_node
+
 __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
     return x / (1.0f + expf(-x));
 }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 7c6938d447..c4b9540f4f 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7663,6 +7663,9 @@ static std::vector> make_test_cases_eval() {
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
+            // long token (n_t > 32, exercises the long_token kernel path)
+            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); // NOTE(review): 64 tokens is a multiple of the 32-token split, so a partial final chunk is not covered
+            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
         }
     }
 