276 lines
11 KiB
Plaintext
276 lines
11 KiB
Plaintext
#include "common.cuh"
|
||
#include "ggml.h"
|
||
#include "solve_tri.cuh"
|
||
|
||
#define MAX_N_FAST 64
|
||
#define MAX_K_FAST 32
|
||
|
||
static __global__ void get_batch_pointers(const float * A,
|
||
float * X,
|
||
const float ** A_ptrs,
|
||
float ** X_ptrs,
|
||
int64_t ne02,
|
||
int64_t total_batches,
|
||
size_t s02,
|
||
size_t s03,
|
||
size_t s2,
|
||
size_t s3) {
|
||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||
if (idx >= total_batches) {
|
||
return;
|
||
}
|
||
|
||
const int64_t i3 = idx / ne02;
|
||
const int64_t i2 = idx % ne02;
|
||
|
||
A_ptrs[idx] = A + i3 * s03 + i2 * s02;
|
||
X_ptrs[idx] = X + i3 * s3 + i2 * s2;
|
||
}
|
||
|
||
static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
|
||
const float * A,
|
||
const float * B,
|
||
float * X,
|
||
int n,
|
||
int k,
|
||
int64_t ne02,
|
||
int64_t ne03,
|
||
size_t s02,
|
||
size_t s03,
|
||
size_t s12,
|
||
size_t s13,
|
||
size_t s2,
|
||
size_t s3,
|
||
cudaStream_t stream) {
|
||
const float alpha = 1.0f;
|
||
const int64_t total_batches = ne02 * ne03;
|
||
if (total_batches == 0) {
|
||
return;
|
||
}
|
||
|
||
// Bulk copy B -> X (contiguous tensors)
|
||
if (X != B) {
|
||
const int64_t total_elements_BX = n * k * total_batches;
|
||
CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||
}
|
||
|
||
const int id = ggml_cuda_get_device();
|
||
|
||
ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
|
||
ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
|
||
|
||
const float ** A_ptrs_dev = A_ptrs_alloc.get();
|
||
float ** X_ptrs_dev = X_ptrs_alloc.get();
|
||
|
||
get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
|
||
total_batches, s02, s03, s2, s3);
|
||
|
||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||
|
||
// Yes, this is necessary, without this we get RMSE errors
|
||
CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
|
||
CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
|
||
CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
|
||
|
||
// revert to standard mode from common.cuh
|
||
CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
|
||
|
||
GGML_UNUSED_VARS(s12, s13);
|
||
}
|
||
|
||
// ======================
|
||
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
|
||
// ======================
|
||
// When ncols_template == 0 the bounds for the loops in this function are not
|
||
// known and can't be unrolled. As we want to keep pragma unroll for all other
|
||
// cases we supress the clang transformation warning here.
|
||
#ifdef __clang__
|
||
# pragma clang diagnostic push
|
||
# pragma clang diagnostic ignored "-Wpass-failed"
|
||
#endif // __clang__
|
||
template <int n_template, int k_template>
|
||
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
|
||
const float * __restrict__ B,
|
||
float * __restrict__ X,
|
||
const uint3 ne02,
|
||
const size_t nb02,
|
||
const size_t nb03,
|
||
const size_t nb12,
|
||
const size_t nb13,
|
||
const size_t nb2,
|
||
const size_t nb3,
|
||
const int n_arg,
|
||
const int k_arg) {
|
||
const int n = n_template == 0 ? n_arg : n_template;
|
||
const int k = k_template == 0 ? k_arg : k_template;
|
||
|
||
const int batch_idx = blockIdx.x;
|
||
const int lane = threadIdx.x;
|
||
const int col_idx = threadIdx.y;
|
||
|
||
if (col_idx >= k) {
|
||
return;
|
||
}
|
||
|
||
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
|
||
const int64_t i02 = i02_i03.y;
|
||
const int64_t i03 = i02_i03.x;
|
||
|
||
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
|
||
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
|
||
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
|
||
|
||
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
|
||
|
||
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
|
||
|
||
#pragma unroll
|
||
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
|
||
const int i0 = i + offset;
|
||
if (i0 < n * n) {
|
||
sA[i0] = A_batch[i0];
|
||
}
|
||
}
|
||
|
||
__syncthreads();
|
||
|
||
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
|
||
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
|
||
|
||
const int half = WARP_SIZE;
|
||
const int nrows_low = (n < half) ? n : half;
|
||
|
||
#pragma unroll
|
||
for (int row = 0; row < nrows_low; ++row) {
|
||
float sum = 0.0f;
|
||
if (lane < row) {
|
||
sum += sA[row * n + lane] * x_low;
|
||
}
|
||
sum = warp_reduce_sum(sum);
|
||
|
||
if (lane == row) {
|
||
x_low = (x_low - sum) / sA[row * n + row];
|
||
}
|
||
}
|
||
|
||
#pragma unroll
|
||
for (int row = half; row < n; ++row) {
|
||
float sum = sA[row * n + lane] * x_low;
|
||
const int j = half + lane;
|
||
if (j < row) {
|
||
sum += sA[row * n + j] * x_high;
|
||
}
|
||
sum = warp_reduce_sum(sum);
|
||
|
||
if (lane == row - half) {
|
||
x_high = (x_high - sum) / sA[row * n + row];
|
||
}
|
||
}
|
||
|
||
#pragma unroll
|
||
for (int rr = 0; rr < 2; ++rr) {
|
||
const int row = rr * WARP_SIZE + lane;
|
||
if (row < n) {
|
||
const float val = (row < half) ? x_low : x_high;
|
||
X_batch[row * k + col_idx] = val;
|
||
}
|
||
}
|
||
}
|
||
#ifdef __clang__
|
||
# pragma clang diagnostic pop
|
||
#endif // __clang__
|
||
|
||
static void solve_tri_f32_cuda(const float * A,
|
||
const float * B,
|
||
float * X,
|
||
int n,
|
||
int k,
|
||
int64_t ne02,
|
||
int64_t ne03,
|
||
size_t nb02,
|
||
size_t nb03,
|
||
size_t nb12,
|
||
size_t nb13,
|
||
size_t nb2,
|
||
size_t nb3,
|
||
cudaStream_t stream) {
|
||
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
||
dim3 threads(WARP_SIZE, k);
|
||
dim3 grid(ne02 * ne03);
|
||
if (n == 64) {
|
||
switch (k) {
|
||
case 32:
|
||
solve_tri_f32_fast<64, 32>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 16:
|
||
solve_tri_f32_fast<64, 16>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 14:
|
||
solve_tri_f32_fast<64, 14>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 12:
|
||
solve_tri_f32_fast<64, 12>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 10:
|
||
solve_tri_f32_fast<64, 10>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 8:
|
||
solve_tri_f32_fast<64, 8>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 6:
|
||
solve_tri_f32_fast<64, 6>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 4:
|
||
solve_tri_f32_fast<64, 4>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 2:
|
||
solve_tri_f32_fast<64, 2>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
case 1:
|
||
solve_tri_f32_fast<64, 1>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
|
||
break;
|
||
default:
|
||
solve_tri_f32_fast<0, 0>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
|
||
}
|
||
} else { // run general case
|
||
solve_tri_f32_fast<0, 0>
|
||
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
|
||
}
|
||
}
|
||
|
||
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||
const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
|
||
const ggml_tensor * src1 = dst->src[1]; // B (n×k)
|
||
|
||
ggml_is_contiguous(src0);
|
||
ggml_is_contiguous(src1);
|
||
|
||
const int64_t n = src0->ne[0];
|
||
const int64_t k = src1->ne[0];
|
||
const int64_t ne02 = src0->ne[2];
|
||
const int64_t ne03 = src0->ne[3];
|
||
|
||
if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
|
||
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
|
||
src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
|
||
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
|
||
dst->nb[3] / sizeof(float), ctx.stream());
|
||
} else {
|
||
solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
|
||
ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
|
||
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
|
||
dst->nb[3] / sizeof(float), ctx.stream());
|
||
}
|
||
}
|