SOLVE_TRI extension to more dimensions (#17793)
* Extended TRI * Fix whitespace * chore: update webui build output * Just use cuBLAS for everything... * Merge both versions * Remove incorrect imports causing failures for CI * Still failing... remove all direct cublas imports and rely on common imports from "common.cuh" * Defines for hipBlas * Aaaand MUSA defines... * I hate this job... * Stupid typo... * Update ggml/src/ggml-cuda/solve_tri.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
c6f6e4f96a
commit
53ecd4fdb9
|
|
@ -4630,9 +4630,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
case GGML_OP_CUMSUM:
|
case GGML_OP_CUMSUM:
|
||||||
case GGML_OP_TRI:
|
case GGML_OP_TRI:
|
||||||
case GGML_OP_DIAG:
|
case GGML_OP_DIAG:
|
||||||
return true;
|
|
||||||
case GGML_OP_SOLVE_TRI:
|
case GGML_OP_SOLVE_TRI:
|
||||||
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
|
return true;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,80 @@
|
||||||
#include "solve_tri.cuh"
|
#include "solve_tri.cuh"
|
||||||
|
|
||||||
#define MAX_N_FAST 64
|
#define MAX_N_FAST 64
|
||||||
|
#define MAX_K_FAST 32
|
||||||
|
|
||||||
|
static __global__ void get_batch_pointers(const float * A,
|
||||||
|
float * X,
|
||||||
|
const float ** A_ptrs,
|
||||||
|
float ** X_ptrs,
|
||||||
|
int64_t ne02,
|
||||||
|
int64_t total_batches,
|
||||||
|
size_t s02,
|
||||||
|
size_t s03,
|
||||||
|
size_t s2,
|
||||||
|
size_t s3) {
|
||||||
|
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
if (idx >= total_batches) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t i3 = idx / ne02;
|
||||||
|
const int64_t i2 = idx % ne02;
|
||||||
|
|
||||||
|
A_ptrs[idx] = A + i3 * s03 + i2 * s02;
|
||||||
|
X_ptrs[idx] = X + i3 * s3 + i2 * s2;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
|
||||||
|
const float * A,
|
||||||
|
const float * B,
|
||||||
|
float * X,
|
||||||
|
int n,
|
||||||
|
int k,
|
||||||
|
int64_t ne02,
|
||||||
|
int64_t ne03,
|
||||||
|
size_t s02,
|
||||||
|
size_t s03,
|
||||||
|
size_t s12,
|
||||||
|
size_t s13,
|
||||||
|
size_t s2,
|
||||||
|
size_t s3,
|
||||||
|
cudaStream_t stream) {
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const int64_t total_batches = ne02 * ne03;
|
||||||
|
if (total_batches == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bulk copy B -> X (contiguous tensors)
|
||||||
|
if (X != B) {
|
||||||
|
const int64_t total_elements_BX = n * k * total_batches;
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||||
|
}
|
||||||
|
|
||||||
|
const int id = ggml_cuda_get_device();
|
||||||
|
|
||||||
|
ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
|
||||||
|
ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
|
||||||
|
|
||||||
|
const float ** A_ptrs_dev = A_ptrs_alloc.get();
|
||||||
|
float ** X_ptrs_dev = X_ptrs_alloc.get();
|
||||||
|
|
||||||
|
get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
|
||||||
|
total_batches, s02, s03, s2, s3);
|
||||||
|
|
||||||
|
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||||
|
|
||||||
|
// Yes, this is necessary, without this we get RMSE errors
|
||||||
|
CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
|
||||||
|
CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
|
||||||
|
CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
|
||||||
|
|
||||||
|
// revert to standard mode from common.cuh
|
||||||
|
CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
|
||||||
|
|
||||||
|
GGML_UNUSED_VARS(s12, s13);
|
||||||
|
}
|
||||||
|
|
||||||
// ======================
|
// ======================
|
||||||
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
|
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
|
||||||
|
|
@ -176,20 +250,26 @@ static void solve_tri_f32_cuda(const float * A,
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
|
const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
|
||||||
const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
|
const ggml_tensor * src1 = dst->src[1]; // B (n×k)
|
||||||
|
|
||||||
ggml_is_contiguous(src0);
|
ggml_is_contiguous(src0);
|
||||||
ggml_is_contiguous(src1);
|
ggml_is_contiguous(src1);
|
||||||
|
|
||||||
const int64_t n = src0->ne[0];
|
const int64_t n = src0->ne[0];
|
||||||
const int64_t k = src1->ne[0];
|
const int64_t k = src1->ne[0];
|
||||||
|
const int64_t ne02 = src0->ne[2];
|
||||||
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
|
||||||
GGML_ASSERT(n <= 64);
|
if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
|
||||||
GGML_ASSERT(k <= 32);
|
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
|
||||||
|
src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
|
||||||
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
|
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
|
||||||
src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
|
dst->nb[3] / sizeof(float), ctx.stream());
|
||||||
|
} else {
|
||||||
|
solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
|
||||||
|
ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
|
||||||
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
|
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
|
||||||
dst->nb[3] / sizeof(float), ctx.stream());
|
dst->nb[3] / sizeof(float), ctx.stream());
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,9 @@
|
||||||
#define CUDA_R_16F HIPBLAS_R_16F
|
#define CUDA_R_16F HIPBLAS_R_16F
|
||||||
#define CUDA_R_16BF HIPBLAS_R_16B
|
#define CUDA_R_16BF HIPBLAS_R_16B
|
||||||
#define CUDA_R_32F HIPBLAS_R_32F
|
#define CUDA_R_32F HIPBLAS_R_32F
|
||||||
|
#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT
|
||||||
|
#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER
|
||||||
|
#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT
|
||||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
|
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
|
||||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
|
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
|
||||||
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
||||||
|
|
@ -30,6 +33,7 @@
|
||||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||||
#define __all_sync(mask, var) __all(var)
|
#define __all_sync(mask, var) __all(var)
|
||||||
#define __any_sync(mask, var) __any(var)
|
#define __any_sync(mask, var) __any(var)
|
||||||
|
#define cublasStrsmBatched hipblasStrsmBatched
|
||||||
#define cublasCreate hipblasCreate
|
#define cublasCreate hipblasCreate
|
||||||
#define cublasDestroy hipblasDestroy
|
#define cublasDestroy hipblasDestroy
|
||||||
#define cublasGemmEx hipblasGemmEx
|
#define cublasGemmEx hipblasGemmEx
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,16 @@
|
||||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
||||||
#define CUBLAS_OP_N MUBLAS_OP_N
|
#define CUBLAS_OP_N MUBLAS_OP_N
|
||||||
#define CUBLAS_OP_T MUBLAS_OP_T
|
#define CUBLAS_OP_T MUBLAS_OP_T
|
||||||
|
#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH
|
||||||
|
#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT
|
||||||
|
#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER
|
||||||
|
#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT
|
||||||
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
||||||
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
|
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
|
||||||
#define CUDA_R_16F MUSA_R_16F
|
#define CUDA_R_16F MUSA_R_16F
|
||||||
#define CUDA_R_16BF MUSA_R_16BF
|
#define CUDA_R_16BF MUSA_R_16BF
|
||||||
#define CUDA_R_32F MUSA_R_32F
|
#define CUDA_R_32F MUSA_R_32F
|
||||||
|
#define cublasStrsmBatched mublasStrsmBatched
|
||||||
#define cublasComputeType_t cudaDataType_t
|
#define cublasComputeType_t cudaDataType_t
|
||||||
#define cublasCreate mublasCreate
|
#define cublasCreate mublasCreate
|
||||||
#define cublasDestroy mublasDestroy
|
#define cublasDestroy mublasDestroy
|
||||||
|
|
|
||||||
|
|
@ -7861,9 +7861,24 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 64, 64, 2, 2 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 79, 79, 5, 3 }, { 417, 79, 5, 3 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 80, 80, 2, 8 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 79, 80, 2, 8 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 81, 80, 2, 8 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 80, 80, 8, 8 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 79, 80, 8, 8 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 81, 80, 8, 8 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 84, 84, 4, 4 }, { 32, 84, 4, 4 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 95, 95, 8, 8 }, { 40, 95, 8, 8 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 32, 128, 4, 4 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 3, 4 }, { 32, 128, 3, 4 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 32, 128, 4, 1 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 }));
|
||||||
|
|
||||||
for (bool v : {false, true}) {
|
for (bool v : {false, true}) {
|
||||||
for (bool circular : {false, true}) {
|
for (bool circular : {false, true}) {
|
||||||
|
|
@ -8064,12 +8079,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
|
||||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 32, 64, 4, 4 }));
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
|
||||||
// qwen3next with CHUNK_SIZE 64
|
// qwen3next with CHUNK_SIZE 64
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
|
||||||
// qwen3next with CHUNK_SIZE 128
|
// qwen3next with CHUNK_SIZE 128
|
||||||
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
|
||||||
|
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 256, 256, 4, 2 }, { 128, 256, 4, 2 }));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
|
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
|
||||||
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
|
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
|
||||||
|
|
|
||||||
Binary file not shown.
Loading…
Reference in New Issue