CUDA: use shared mem for ssm_conv (#20128)

* CUDA: use shared mem for ssm_conv
* fuse silu + ssm_conv
* fuse unary + mul
* enable for fp16
* formatting

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
commit 1e38a7a6fa (parent 388baabc06)
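For context, these are the two graph shapes the new fusion paths target, sketched against the public ggml API. This is illustrative only; the helper name build_patterns and the tensor shapes are assumptions, not part of the commit:

    #include "ggml.h"

    // Sketch: the two node patterns the new fusion paths recognize.
    static void build_patterns(struct ggml_context * ctx, struct ggml_cgraph * gf,
                               struct ggml_tensor * conv_x,  // {d_conv - 1 + n_t, d_inner, n_s}
                               struct ggml_tensor * conv_w,  // {d_conv, d_inner}
                               struct ggml_tensor * gate) {  // same shape as the activation
        // Pattern 1: SSM_CONV -> SILU becomes a single kernel launch.
        struct ggml_tensor * conv = ggml_ssm_conv(ctx, conv_x, conv_w);
        struct ggml_tensor * act  = ggml_silu(ctx, conv);

        // Pattern 2: UNARY (SILU/SIGMOID/SOFTPLUS) -> MUL, the gated-activation shape.
        struct ggml_tensor * out = ggml_mul(ctx, ggml_silu(ctx, act), gate);
        ggml_build_forward_expand(gf, out);
    }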
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3348,6 +3348,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
         return true;
     }
 
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
+        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
+        const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
+        const ggml_tensor * silu     = cgraph->nodes[node_idx+1];
+
+        if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        return true;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
+        && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
+        const ggml_tensor * unary = cgraph->nodes[node_idx];
+        const ggml_tensor * mul   = cgraph->nodes[node_idx+1];
+
+        if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
+            return false;
+        }
+
+        if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
+            return false;
+        }
+
+        if (unary->type != mul->type) {
+            return false;
+        }
+
+        const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
+        if (other->type != unary->type) {
+            return false;
+        }
+        if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {
+            return false;
+        }
+
+        return true;
+    }
+
     if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
         && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
         const ggml_tensor *scale = cgraph->nodes[node_idx];
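A consequence of the shape and contiguity checks: a broadcasting MUL is left unfused. An illustrative sketch with assumed shapes, continuing the ggml-API example above:

    // Valid graph, but rejected for fusion: gate_row broadcasts over rows,
    // so ggml_are_same_shape(other, unary) fails and the unfused
    // UNARY and MUL kernels run instead.
    struct ggml_tensor * act   = ggml_silu(ctx, x);            // shape {n, r}
    struct ggml_tensor * gated = ggml_mul(ctx, act, gate_row); // gate_row: {n, 1}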
@@ -3836,6 +3876,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
             continue;
         }
 
+        if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
+            ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
+        if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
+            ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
+            ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
+            ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
             i += 2;
             ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
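The dispatch above relies on a simple contract: a fused launch writes the buffer of the last node in the fused window, and the loop advances past the consumed nodes. A skeleton of the pattern with hypothetical helper names (can_fuse_window, launch_fused and launch_single are not real functions):

    // Illustrative skeleton of the fusion dispatch pattern used above.
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor * node = cgraph->nodes[i];
        if (can_fuse_window(cgraph, i)) {           // hypothetical predicate
            launch_fused(node, cgraph->nodes[i+1]); // result lands in nodes[i+1]->data
            i++;                                    // skip the node consumed by the fusion
            continue;
        }
        launch_single(node);                        // hypothetical single-op launch
    }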
--- a/ggml/src/ggml-cuda/ssm-conv.cu
+++ b/ggml/src/ggml-cuda/ssm-conv.cu
@@ -1,6 +1,7 @@
 #include "ssm-conv.cuh"
+#include "unary.cuh"
 
-template <size_t split_d_inner, size_t d_conv>
+template <bool apply_silu, size_t split_d_inner, size_t d_conv>
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
@@ -41,11 +42,11 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
         for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
-        y_block[i * stride_y + tid] = sumf;
+        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }
 
-template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
+template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
@@ -65,36 +66,46 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     const int stride_w = src1_nb1 / sizeof(float);
     const int stride_y = dst_nb1 / sizeof(float);
 
-    float x[d_conv] = { 0.0f };
-    float w[d_conv] = { 0.0f };
+    const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t);
+    const int n_cols = d_conv - 1 + split_n_t;
+
+    extern __shared__ float smem[];
+
+    constexpr int load_cols   = d_conv - 1 + split_n_t;
+    constexpr int total_elems = split_d_inner * load_cols;
+    int row = tid / load_cols;
+    int col = tid % load_cols;
+#pragma unroll
+    for (int idx = tid; idx < total_elems; idx += split_d_inner) {
+        if (row < (int)split_d_inner) {
+            smem[row * n_cols + col] = x_block[row * stride_x + col];
+        }
+
+        col += split_d_inner;
+        row += col / load_cols;
+        col  = col % load_cols;
+    }
+    __syncthreads();
 
+    // Load weights into registers (done once, small)
+    float w[d_conv] = { 0.0f };
 #pragma unroll
     for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }
 
-#pragma unroll
-    for (int64_t i = 0; i < split_n_t; i++) {
-        if (bidz * split_n_t + i < n_t) {
-            float sumf = 0.0f;
-
-            if (i == 0) {
-                for (size_t j = 0; j < d_conv; j++) {
-                    x[j] = x_block[tid * stride_x + j];
-                }
-            } else {
-                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
-            }
-
-#pragma unroll
-            for (size_t j = 0; j < d_conv; j++) {
-                sumf += x[(i + j) % d_conv] * w[j];
-            }
-            y_block[i * stride_y + tid] = sumf;
-        }
+    // Compute from shared memory
+    for (int64_t i = 0; i < local_n_t; i++) {
+        float sumf = 0.0f;
+#pragma unroll
+        for (size_t j = 0; j < d_conv; j++) {
+            sumf += smem[tid * n_cols + i + j] * w[j];
+        }
+        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }
 
+template <bool apply_silu>
 static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                               const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
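The staging loop above replaces the old per-thread ring buffer: threads now cooperate so that consecutive threads read consecutive addresses (coalesced), and each staged element is then reused up to d_conv times from shared memory. A self-contained toy kernel showing the same (row, col) carry trick, assuming blockDim.x == threads (names are illustrative, not part of the commit):

    #include <cuda_runtime.h>

    // Stage a rows x cols tile of floats in shared memory with one block,
    // then copy it back out; the (row, col) bookkeeping matches the kernel above.
    template <int rows, int cols, int threads>
    static __global__ void tile_roundtrip(const float * __restrict__ src, float * __restrict__ dst) {
        __shared__ float smem[rows * cols];
        int row = threadIdx.x / cols;
        int col = threadIdx.x % cols;
        for (int idx = threadIdx.x; idx < rows * cols; idx += threads) {
            smem[row * cols + col] = src[row * cols + col]; // row * cols + col == idx here
            col += threads;    // advance by the block-wide stride ...
            row += col / cols; // ... and carry the overflow into the row
            col %= cols;
        }
        __syncthreads();
        for (int idx = threadIdx.x; idx < rows * cols; idx += threads) {
            dst[idx] = smem[idx];
        }
    }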
@@ -106,12 +117,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     constexpr int kNC = decltype(NC)::value;
     if (n_t <= 32) {
         const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-        ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
-                                                                   dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+        ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+                                                                               dst, dst_nb0, dst_nb1, dst_nb2, n_t);
     } else {
         const int64_t split_n_t = 32;
         dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-        ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
+        const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
+        ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
             src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
     }
 };
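A quick sanity check on the dynamic shared-memory request, assuming the threads = 128 block size used by this file and a typical d_conv = 4 (both values are assumptions here):

    // Back-of-envelope check of the smem_size formula above.
    constexpr size_t threads   = 128;  // assumed block size
    constexpr size_t d_conv    = 4;    // typical Mamba conv width
    constexpr size_t split_n_t = 32;
    constexpr size_t smem_size = threads * (d_conv - 1 + split_n_t) * sizeof(float);
    static_assert(smem_size == 17920, "17.5 KiB, comfortably under the 48 KiB default limit");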
@@ -124,27 +136,36 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     }
 }
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
     const struct ggml_tensor * src0 = dst->src[0]; // conv_x
     const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+    const bool fuse_silu = silu_dst != nullptr;
+
+    // When fusing, write to silu_dst (the node downstream references).
+    const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
 
     const int64_t nc = src1->ne[0]; // d_conv
     const int64_t nr = src0->ne[1]; // d_inner
-    const int64_t n_t = dst->ne[1]; // tokens per sequence
-    const int64_t n_s = dst->ne[2]; // number of sequences in the batch
+    const int64_t n_t = out->ne[1]; // tokens per sequence
+    const int64_t n_s = out->ne[2]; // number of sequences in the batch
 
-    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(out->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
-    float * dst_d = (float *) dst->data;
+    float * dst_d = (float *) out->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
-                      dst->nb[2], nc, nr, n_t, n_s, stream);
+    GGML_ASSERT(out->type == GGML_TYPE_F32);
+    if (fuse_silu) {
+        ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+                                out->nb[2], nc, nr, n_t, n_s, stream);
+    } else {
+        ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+                                 out->nb[2], nc, nr, n_t, n_s, stream);
+    }
 }
--- a/ggml/src/ggml-cuda/ssm-conv.cuh
+++ b/ggml/src/ggml-cuda/ssm-conv.cuh
@@ -1,3 +1,3 @@
 #include "common.cuh"
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
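Defaulting silu_dst to nullptr keeps every existing call site compiling unchanged; only the fused path passes the downstream SILU node. Both call shapes appear in this commit:

    ggml_cuda_op_ssm_conv(*cuda_ctx, node);                      // unfused: writes node->data
    ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);  // fused: writes the SILU node's data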
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -560,3 +560,58 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
 }
 
+/* fused unary + mul */
+
+template <float (*op)(float)>
+static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+    // unary_node: UNARY op applied to unary_node->src[0]
+    // mul_node:   MUL(a, b) where one of a/b is unary_node
+    // Output goes to mul_node->data
+
+    const ggml_tensor * unary_src = unary_node->src[0]; // input to the unary op
+    const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous_1(unary_src));
+    GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src));
+    GGML_ASSERT(ggml_is_contiguous_1(other_src));
+    GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src));
+    GGML_ASSERT(ggml_are_same_shape(unary_src, other_src));
+
+    GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16);
+    GGML_ASSERT(unary_src->type == other_src->type);
+    GGML_ASSERT(unary_src->type == mul_node->type);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t k  = ggml_nelements(mul_node);
+    const int64_t nc = unary_src->ne[0];
+    const int64_t unary_stride = unary_src->nb[1];
+    const int64_t other_stride = other_src->nb[1];
+
+    if (unary_src->type == GGML_TYPE_F16) {
+        unary_gated_cuda<op>((const half *) unary_src->data, (const half *) other_src->data,
+                             (half *) mul_node->data, k, nc,
+                             unary_stride / sizeof(half), other_stride / sizeof(half), stream);
+    } else {
+        unary_gated_cuda<op>((const float *) unary_src->data, (const float *) other_src->data,
+                             (float *) mul_node->data, k, nc,
+                             unary_stride / sizeof(float), other_stride / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+    switch (ggml_get_unary_op(unary_node)) {
+        case GGML_UNARY_OP_SILU:
+            ggml_cuda_op_unary_mul_impl<op_silu>(ctx, unary_node, mul_node);
+            break;
+        case GGML_UNARY_OP_SIGMOID:
+            ggml_cuda_op_unary_mul_impl<op_sigmoid>(ctx, unary_node, mul_node);
+            break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            ggml_cuda_op_unary_mul_impl<op_softplus>(ctx, unary_node, mul_node);
+            break;
+        default:
+            GGML_ABORT("Unsupported unary op for fused unary+mul");
+    }
+}
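unary_gated_cuda itself is not part of this diff; it is the pre-existing GLU helper in unary.cu that the new code reuses. A minimal sketch of the kernel shape such a helper launches, with assumed signatures (illustrative only, not the real implementation):

    // op(x) * g, elementwise, with per-row element strides o0/o1 so that rows of
    // n contiguous elements may come from non-contiguous 2D views.
    template <float (*op)(float), typename T>
    static __global__ void unary_gated_sketch(const T * x, const T * g, T * dst,
                                              const int64_t k, const int64_t n,
                                              const int64_t o0, const int64_t o1) {
        const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
        if (i >= k) {
            return;
        }
        const int64_t j0 = (i / n) * o0 + (i % n); // index into x
        const int64_t j1 = (i / n) * o1 + (i % n); // index into g
        dst[i] = (T) (op((float) x[j0]) * (float) g[j1]);
    }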
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -89,6 +89,8 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
 void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);
+
 __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
     return x / (1.0f + expf(-x));
 }
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7663,6 +7663,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
             test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
+            // long token (n_t > 32, exercises the long_token kernel path)
+            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}));
+            test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}));
         }
     }
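Note: the new long-token cases can be exercised in isolation with the backend-ops harness, for example 'test-backend-ops test -o SSM_CONV' (assuming the harness's usual -o op filter), which checks the CUDA backend against the CPU reference.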