diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 6643f243b1..35d48f614c 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -2,6 +2,12 @@ #include "mmf.cuh" #include "mmid.cuh" +constexpr int mmf_rows_per_block = 32; +constexpr int mmf_rows_per_block_cdna = 64; + +static int get_mmf_rows_per_block(const int cc, const int warp_size) { + return warp_size; +} void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); @@ -89,30 +95,60 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr ids_info_ptr = &ids_info; } + const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; + const int warp_size = ggml_cuda_info().devices[device].warp_size; + const int rows_per_block = get_mmf_rows_per_block(cc, warp_size); + + if (rows_per_block != mmf_rows_per_block && rows_per_block != mmf_rows_per_block_cdna) { + GGML_ABORT("unsupported rows_per_block: %i", rows_per_block); + } + switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, - ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + if (rows_per_block == mmf_rows_per_block) { + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } else { + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, - ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + if (rows_per_block == mmf_rows_per_block) { + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } else { + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, - ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + if (rows_per_block == mmf_rows_per_block) { + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } else { + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -140,7 +176,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } - if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { + if (src0_ne[1] % get_mmf_rows_per_block(cc, warp_size) != 0) { return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index e36730948f..e6e9d79aa1 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -6,8 +6,6 @@ using namespace ggml_cuda_mma; -#define MMF_ROWS_PER_BLOCK 32 - struct mmf_ids_data { const int32_t * ids_src_compact = nullptr; const int32_t * ids_dst_compact = nullptr; @@ -228,21 +226,30 @@ static __global__ void mul_mat_f( return; } - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); #pragma unroll for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; - sum += buf_iw[j*kiw + i]; + sum[i1] += buf_iw[j*kiw + i]; + } } if constexpr (!has_ids) { - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } else { const int slot = (j < cols_per_block) ? slot_map[j] : -1; if (slot >= 0 && (col_base + j) < ncols_dst_total) { - dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } } } @@ -497,13 +504,16 @@ static __global__ void mul_mat_f_ids( return; } - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); #pragma unroll for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; - sum += buf_iw[j*kiw + i]; + sum[i1] += buf_iw[j * kiw + i]; + } } const int global_j = col_base + j; @@ -513,7 +523,10 @@ static __global__ void mul_mat_f_ids( const int token = (int) qrm.x; if (token < ncols_dst_total) { const int slot = (int) qrm.y; - dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } } } @@ -529,7 +542,7 @@ static __global__ void mul_mat_f_ids( #endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) } -template +template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst, @@ -553,7 +566,7 @@ static inline void mul_mat_f_switch_ids( const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); - mul_mat_f_ids<<>> + mul_mat_f_ids<<>> (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, @@ -564,19 +577,19 @@ static inline void mul_mat_f_switch_ids( dim3 block_nums_ids = block_nums; block_nums_ids.y *= col_tiles; - mul_mat_f<<>> + mul_mat_f<<>> (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { - mul_mat_f<<>> + mul_mat_f<<>> (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } } -template +template void mul_mat_f_cuda( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, @@ -614,7 +627,6 @@ void mul_mat_f_cuda( } } - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; @@ -628,56 +640,56 @@ void mul_mat_f_cuda( switch (nwarps_best) { case 1: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 2: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 3: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 4: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 5: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 6: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 7: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 8: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, @@ -691,7 +703,7 @@ void mul_mat_f_cuda( GGML_UNUSED_VARS(nchannels_y); } -template +template static void mul_mat_f_switch_cols_per_block( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, @@ -708,82 +720,82 @@ static void mul_mat_f_switch_cols_per_block( switch (ncols_case) { case 1: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 2: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 3: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 4: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 5: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 6: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 7: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 8: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 9: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 10: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 11: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 12: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 13: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 14: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 15: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 16: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break;