SOLVE_TRI CUDA kernel for small matrices (#17457)
This commit is contained in:
parent
efaaccdd69
commit
cd0e3a7a3b
|
|
@ -53,6 +53,7 @@
|
||||||
#include "ggml-cuda/set.cuh"
|
#include "ggml-cuda/set.cuh"
|
||||||
#include "ggml-cuda/set-rows.cuh"
|
#include "ggml-cuda/set-rows.cuh"
|
||||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||||
|
#include "ggml-cuda/solve_tri.cuh"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
@ -2717,6 +2718,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_OP_OPT_STEP_SGD:
|
case GGML_OP_OPT_STEP_SGD:
|
||||||
ggml_cuda_opt_step_sgd(ctx, dst);
|
ggml_cuda_opt_step_sgd(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_SOLVE_TRI:
|
||||||
|
ggml_cuda_op_solve_tri(ctx, dst);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -4255,6 +4259,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_OPT_STEP_ADAMW:
|
case GGML_OP_OPT_STEP_ADAMW:
|
||||||
case GGML_OP_OPT_STEP_SGD:
|
case GGML_OP_OPT_STEP_SGD:
|
||||||
return true;
|
return true;
|
||||||
|
case GGML_OP_SOLVE_TRI:
|
||||||
|
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,203 @@
|
||||||
|
#include "common.cuh"
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "solve_tri.cuh"
|
||||||
|
|
||||||
|
#define MAX_N_FAST 64
|
||||||
|
#define MAX_K_FAST 32
|
||||||
|
|
||||||
|
// ======================
|
||||||
|
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
|
||||||
|
// ======================
|
||||||
|
// When ncols_template == 0 the bounds for the loops in this function are not
|
||||||
|
// known and can't be unrolled. As we want to keep pragma unroll for all other
|
||||||
|
// cases we supress the clang transformation warning here.
|
||||||
|
#ifdef __clang__
|
||||||
|
# pragma clang diagnostic push
|
||||||
|
# pragma clang diagnostic ignored "-Wpass-failed"
|
||||||
|
#endif // __clang__
|
||||||
|
template <int n_template, int k_template>
|
||||||
|
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
|
||||||
|
const float * __restrict__ B,
|
||||||
|
float * __restrict__ X,
|
||||||
|
const uint3 ne02,
|
||||||
|
const size_t nb02,
|
||||||
|
const size_t nb03,
|
||||||
|
const size_t nb12,
|
||||||
|
const size_t nb13,
|
||||||
|
const size_t nb2,
|
||||||
|
const size_t nb3,
|
||||||
|
const int n_arg,
|
||||||
|
const int k_arg) {
|
||||||
|
const int n = n_template == 0 ? n_arg : n_template;
|
||||||
|
const int k = k_template == 0 ? k_arg : k_template;
|
||||||
|
|
||||||
|
const int batch_idx = blockIdx.x;
|
||||||
|
const int lane = threadIdx.x;
|
||||||
|
const int col_idx = threadIdx.y;
|
||||||
|
|
||||||
|
if (col_idx >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
|
||||||
|
const int64_t i02 = i02_i03.y;
|
||||||
|
const int64_t i03 = i02_i03.x;
|
||||||
|
|
||||||
|
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
|
||||||
|
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
|
||||||
|
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
|
||||||
|
|
||||||
|
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
|
||||||
|
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
|
||||||
|
|
||||||
|
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
|
||||||
|
int i0 = i + offset;
|
||||||
|
if (i0 < n * n) {
|
||||||
|
sA[i0] = A_batch[i0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < rows_per_warp; i++) {
|
||||||
|
const int i0 = lane + i * WARP_SIZE;
|
||||||
|
if (i0 < n) {
|
||||||
|
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int row = 0; row < n; ++row) {
|
||||||
|
float sum = 0.0f;
|
||||||
|
|
||||||
|
{
|
||||||
|
int j = lane;
|
||||||
|
if (j < row) {
|
||||||
|
sum += sA[row * n + j] * sXt[col_idx * n + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (row >= WARP_SIZE) {
|
||||||
|
int j = WARP_SIZE + lane;
|
||||||
|
if (j < row) {
|
||||||
|
sum += sA[row * n + j] * sXt[col_idx * n + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = warp_reduce_sum(sum);
|
||||||
|
|
||||||
|
if (lane == 0) {
|
||||||
|
const float b_val = sXt[col_idx * n + row];
|
||||||
|
const float a_diag = sA[row * n + row];
|
||||||
|
// no safeguards for division by zero because that indicates corrupt
|
||||||
|
// data anyway
|
||||||
|
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < rows_per_warp; i++) {
|
||||||
|
const int i0 = lane + i * WARP_SIZE;
|
||||||
|
if (i0 < n) {
|
||||||
|
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#ifdef __clang__
|
||||||
|
# pragma clang diagnostic pop
|
||||||
|
#endif // __clang__
|
||||||
|
|
||||||
|
static void solve_tri_f32_cuda(const float * A,
|
||||||
|
const float * B,
|
||||||
|
float * X,
|
||||||
|
int n,
|
||||||
|
int k,
|
||||||
|
int64_t ne02,
|
||||||
|
int64_t ne03,
|
||||||
|
size_t nb02,
|
||||||
|
size_t nb03,
|
||||||
|
size_t nb12,
|
||||||
|
size_t nb13,
|
||||||
|
size_t nb2,
|
||||||
|
size_t nb3,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
||||||
|
dim3 threads(WARP_SIZE, k);
|
||||||
|
dim3 grid(ne02 * ne03);
|
||||||
|
if (n == 64) {
|
||||||
|
switch (k) {
|
||||||
|
case 32:
|
||||||
|
solve_tri_f32_fast<64, 32>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 16:
|
||||||
|
solve_tri_f32_fast<64, 16>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 14:
|
||||||
|
solve_tri_f32_fast<64, 14>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 12:
|
||||||
|
solve_tri_f32_fast<64, 12>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 10:
|
||||||
|
solve_tri_f32_fast<64, 10>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
solve_tri_f32_fast<64, 8>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
solve_tri_f32_fast<64, 6>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
solve_tri_f32_fast<64, 4>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
solve_tri_f32_fast<64, 2>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
solve_tri_f32_fast<64, 1>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
solve_tri_f32_fast<0, 0>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
|
||||||
|
}
|
||||||
|
} else { // run general case
|
||||||
|
solve_tri_f32_fast<0, 0>
|
||||||
|
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
|
||||||
|
const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
|
||||||
|
|
||||||
|
ggml_is_contiguous(src0);
|
||||||
|
ggml_is_contiguous(src1);
|
||||||
|
|
||||||
|
const int64_t n = src0->ne[0];
|
||||||
|
const int64_t k = src1->ne[0];
|
||||||
|
|
||||||
|
GGML_ASSERT(n <= 64);
|
||||||
|
GGML_ASSERT(k <= 32);
|
||||||
|
|
||||||
|
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
|
||||||
|
src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
|
||||||
|
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
|
||||||
|
dst->nb[3] / sizeof(float), ctx.stream());
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
|
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
@ -7935,6 +7935,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
|
||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
|
||||||
|
|
||||||
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
|
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
|
||||||
for (ggml_type type_a : all_types) {
|
for (ggml_type type_a : all_types) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue